zuxin-llm committed on
Commit b2c41bb
1 Parent(s): 0f2f7fe

Upload 2 files

Files changed (2):
  1. README.md +429 -3
  2. example/multi_turn_xlam.ipynb +459 -0
README.md CHANGED
@@ -1,3 +1,429 @@
---
extra_gated_heading: >-
  Acknowledge to follow the corresponding license to access the
  repository
extra_gated_button_content: Agree and access repository
extra_gated_fields:
  First Name: text
  Last Name: text
  Country: country
  Affiliation: text
license: cc-by-nc-4.0
datasets:
- Salesforce/xlam-function-calling-60k
language:
- en
pipeline_tag: text-generation
tags:
- function-calling
- LLM Agent
- tool-use
- mistral
- pytorch
---

<p align="center">
<img width="500px" alt="xLAM" src="https://huggingface.co/datasets/jianguozhang/logos/resolve/main/xlam-no-background.png">
</p>
<p align="center">
<a href="">[Homepage]</a> |
<a href="">[Paper]</a> |
<a href="https://github.com/SalesforceAIResearch/xLAM">[Github]</a>
</p>
<hr>

Welcome to the xLAM model family! [Large Action Models (LAMs)](https://blog.salesforceairesearch.com/large-action-models/) are advanced large language models designed to enhance decision-making and translate user intentions into executable actions that interact with the world. LAMs autonomously plan and execute tasks to achieve specific goals, serving as the brains of AI agents. They have the potential to automate workflow processes across various domains, making them invaluable for a wide range of applications.

**This model release is for research purposes only. A new and enhanced version of xLAM will soon be available exclusively to customers on our platform.**

## Table of Contents
- [Model Series](#model-series)
- [Repository Overview](#repository-overview)
- [Benchmark Results](#benchmark-results)
- [Usage](#usage)
  - [Basic Usage with Hugging Face](#basic-usage-with-hugging-face)
- [License](#license)
- [Citation](#citation)

## Model Series

We provide a series of xLAMs in different sizes to cater to various applications, including those optimized for function-calling and general agent applications:

| Model        | # Total Params | Context Length | Download Model | Download GGUF files |
|--------------|----------------|----------------|----------------|---------------------|
| xLAM-1b-fc-r | 1.35B | 16k | [🤗 Link](https://huggingface.co/Salesforce/xLAM-1b-fc-r) | [🤗 Link](https://huggingface.co/Salesforce/xLAM-1b-fc-r-gguf) |
| xLAM-7b-fc-r | 6.91B | 4k  | [🤗 Link](https://huggingface.co/Salesforce/xLAM-7b-fc-r) | [🤗 Link](https://huggingface.co/Salesforce/xLAM-7b-fc-r-gguf) |
| xLAM-7b-r    | 7.24B | 32k | [🤗 Link](https://huggingface.co/Salesforce/xLAM-7b-r)    | -- |
| xLAM-8x7b-r  | 46.7B | 32k | [🤗 Link](https://huggingface.co/Salesforce/xLAM-8x7b-r)  | -- |
| xLAM-8x22b-r | 141B  | 64k | [🤗 Link](https://huggingface.co/Salesforce/xLAM-8x22b-r) | -- |
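
To fetch one of these checkpoints programmatically, a minimal sketch using the `huggingface_hub` library could look like the following (this assumes you have run `pip install huggingface_hub` and, since the repository is gated, accepted the license and logged in via `huggingface-cli login`):

```python
from huggingface_hub import snapshot_download

# Download the xLAM-7b-r checkpoint into the local Hugging Face cache.
# Gated repositories require an authenticated session with the license accepted.
local_path = snapshot_download(repo_id="Salesforce/xLAM-7b-r")
print(f"Model files downloaded to: {local_path}")
```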
For our function-calling series (more details are included [here](https://huggingface.co/Salesforce/xLAM-7b-fc-r)), we also provide their quantized [GGUF](https://huggingface.co/docs/hub/en/gguf) files for efficient deployment and execution. GGUF is a file format designed to efficiently store and load large language models, making it ideal for running AI models on local devices with limited resources, enabling offline functionality and enhanced privacy.
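
As one illustration (not part of the official instructions), a GGUF file can be run locally through the `llama-cpp-python` bindings; this sketch assumes `pip install llama-cpp-python` and a GGUF file downloaded from one of the `-gguf` repositories above, with a placeholder file name:

```python
from llama_cpp import Llama

# Load a quantized xLAM function-calling model from a local GGUF file.
# The file name below is a placeholder; use the file you actually downloaded.
llm = Llama(model_path="xLAM-7b-fc-r.Q4_K_M.gguf", n_ctx=4096)

# Run a single completion over an already-formatted prompt string.
result = llm("YOUR_FORMATTED_PROMPT", max_tokens=512)
print(result["choices"][0]["text"])
```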

For more details, check our [GitHub](https://github.com/SalesforceAIResearch/xLAM) and [paper]().


## Repository Overview

This repository is about the general tool-use series. For more specialized function-calling models, please take a look at our `fc` series [here](https://huggingface.co/Salesforce/xLAM-7b-fc-r).

The instructions below will guide you through the setup, usage, and integration of our model series with Hugging Face.
### Framework Versions

- Transformers 4.41.0
- PyTorch 2.3.0+cu121
- Datasets 2.19.1
- Tokenizers 0.19.1
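
To confirm your environment matches, you can print the installed versions (a quick sanity check, not a required step):

```python
import datasets
import tokenizers
import torch
import transformers

# Expect versions at or above those listed above.
print("Transformers:", transformers.__version__)
print("PyTorch:", torch.__version__)
print("Datasets:", datasets.__version__)
print("Tokenizers:", tokenizers.__version__)
```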

## Usage

### Basic Usage with Hugging Face

To use the model from Hugging Face, please first install the `transformers` library:
```bash
pip install "transformers>=4.41.0"
```

Please note that our model works best with our provided prompt format.
It allows us to extract JSON output that is similar to the [function-calling mode of ChatGPT](https://platform.openai.com/docs/guides/function-calling).

We use the following example to illustrate how to use our model for 1) a single-turn use case and 2) a multi-turn use case.

#### 1. Single-turn use case

````python
import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

torch.random.manual_seed(0)

model_name = "Salesforce/xLAM-7b-r"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype="auto", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Please use our provided instruction prompt for best performance
task_instruction = """
Based on the previous context and API request history, generate an API request or a response as an AI assistant.""".strip()

format_instruction = """
The output should be of the JSON format, which specifies a list of generated function calls. The example format is as follows, please make sure the parameter type is correct. If no function call is needed, please make
tool_calls an empty list "[]".
```
{"thought": "the thought process, or an empty string", "tool_calls": [{"name": "api_name1", "arguments": {"argument1": "value1", "argument2": "value2"}}]}
```
""".strip()

# Define the input query and available tools
query = "What's the weather like in New York in fahrenheit?"

get_weather_api = {
    "name": "get_weather",
    "description": "Get the current weather for a location",
    "parameters": {
        "type": "object",
        "properties": {
            "location": {
                "type": "string",
                "description": "The city and state, e.g. San Francisco, New York"
            },
            "unit": {
                "type": "string",
                "enum": ["celsius", "fahrenheit"],
                "description": "The unit of temperature to return"
            }
        },
        "required": ["location"]
    }
}

search_api = {
    "name": "search",
    "description": "Search for information on the internet",
    "parameters": {
        "type": "object",
        "properties": {
            "query": {
                "type": "string",
                "description": "The search query, e.g. 'latest news on AI'"
            }
        },
        "required": ["query"]
    }
}

openai_format_tools = [get_weather_api, search_api]

# Helper function to convert OpenAI-format tools to our more concise xLAM format
def convert_to_xlam_tool(tools):
    """Recursively convert OpenAI-format tool definitions to the xLAM format."""
    if isinstance(tools, dict):
        return {
            "name": tools["name"],
            "description": tools["description"],
            "parameters": {k: v for k, v in tools["parameters"].get("properties", {}).items()}
        }
    elif isinstance(tools, list):
        return [convert_to_xlam_tool(tool) for tool in tools]
    else:
        return tools

def build_conversation_history_prompt(conversation_history: list):
    parsed_history = []
    for step_data in conversation_history:
        parsed_history.append({
            "step_id": step_data["step_id"],
            "thought": step_data["thought"],
            "tool_calls": step_data["tool_calls"],
            "next_observation": step_data["next_observation"],
            "user_input": step_data["user_input"]
        })

    history_string = json.dumps(parsed_history)
    return f"\n[BEGIN OF HISTORY STEPS]\n{history_string}\n[END OF HISTORY STEPS]\n"


# Helper function to build the input prompt for our model
def build_prompt(task_instruction: str, format_instruction: str, tools: list, query: str, conversation_history: list):
    prompt = f"[BEGIN OF TASK INSTRUCTION]\n{task_instruction}\n[END OF TASK INSTRUCTION]\n\n"
    prompt += f"[BEGIN OF AVAILABLE TOOLS]\n{json.dumps(tools)}\n[END OF AVAILABLE TOOLS]\n\n"
    prompt += f"[BEGIN OF FORMAT INSTRUCTION]\n{format_instruction}\n[END OF FORMAT INSTRUCTION]\n\n"
    prompt += f"[BEGIN OF QUERY]\n{query}\n[END OF QUERY]\n\n"

    if len(conversation_history) > 0:
        prompt += build_conversation_history_prompt(conversation_history)
    return prompt

# Build the input and start the inference
xlam_format_tools = convert_to_xlam_tool(openai_format_tools)

conversation_history = []
content = build_prompt(task_instruction, format_instruction, xlam_format_tools, query, conversation_history)

messages = [
    {'role': 'user', 'content': content}
]

inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)

# tokenizer.eos_token_id is the id of the <|EOT|> token
outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
agent_action = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
````

Then you should be able to see the following output string in JSON format:

```shell
{"thought": "I need to get the current weather for New York in fahrenheit.", "tool_calls": [{"name": "get_weather", "arguments": {"location": "New York", "unit": "fahrenheit"}}]}
```
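
From here, the decoded string can be parsed and dispatched to your actual tool implementations. The snippet below is a minimal sketch of that step, continuing from the code above; the `TOOL_REGISTRY` mapping and the stub implementation are illustrative, not part of the model card:

```python
# Map tool names to callables; the implementation here is an illustrative stub.
def get_weather(location, unit="fahrenheit"):
    return {"location": location, "temperature": 81, "unit": unit}

TOOL_REGISTRY = {"get_weather": get_weather}

parsed = json.loads(agent_action)  # may raise json.JSONDecodeError on malformed output
for call in parsed.get("tool_calls", []):
    tool_fn = TOOL_REGISTRY.get(call["name"])
    if tool_fn is not None:
        observation = tool_fn(**call.get("arguments", {}))
        print(call["name"], "->", observation)
```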

#### 2. Multi-turn use case

We also support multi-turn interaction with our model series. Here is an example of the next round of interaction, continuing from the example above:

````python
def parse_agent_action(agent_action: str):
    """
    Given an agent's action, parse it to add to conversation history
    """
    try:
        parsed_agent_action_json = json.loads(agent_action)
    except json.JSONDecodeError:
        return "", []

    thought = parsed_agent_action_json.get("thought", "")
    tool_calls = parsed_agent_action_json.get("tool_calls", [])

    return thought, tool_calls

def update_conversation_history(conversation_history: list, agent_action: str, environment_response: str, user_input: str):
    """
    Update the conversation history list based on the new agent_action, environment_response, and/or user_input
    """
    thought, tool_calls = parse_agent_action(agent_action)
    new_step_data = {
        "step_id": len(conversation_history) + 1,
        "thought": thought,
        "tool_calls": tool_calls,
        "next_observation": environment_response,
        "user_input": user_input,
    }

    conversation_history.append(new_step_data)

def get_environment_response(agent_action: str):
    """
    Get the environment response for the agent_action
    """
    # TODO: add custom implementation here
    error_message, response_message = "", ""
    return {"error": error_message, "response": response_message}

# ------------- before here are the steps to get agent_action from the example above ----------

# 1. Get the next state after the agent's response:
# The next 2 lines are examples of getting the environment response and user_input.
# Depending on the particular usage, we can have either one or both of those.
environment_response = get_environment_response(agent_action)
user_input = "Now, search on the Internet for cute puppies"

# 2. After we get environment_response and/or user_input, we add them to our conversation history
update_conversation_history(conversation_history, agent_action, environment_response, user_input)

# 3. We can now build the prompt
content = build_prompt(task_instruction, format_instruction, xlam_format_tools, query, conversation_history)

# 4. Now, we just retrieve the inputs for the LLM
messages = [
    {'role': 'user', 'content': content}
]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)

# 5. Generate the outputs & decode
# tokenizer.eos_token_id is the id of the <|EOT|> token
outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
agent_action = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
````

Given the new user input, the corresponding output should now be a `search` call along these lines:
```shell
{"thought": "", "tool_calls": [{"name": "search", "arguments": {"query": "cute puppies"}}]}
```

We highly recommend using our provided prompt format and helper functions to get the best function-calling performance from our model.
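
Putting the helpers together, a full interaction loop follows the shape below. This sketch simply restates the steps above in one place; the `generate_agent_action` helper name and the fixed turn budget are arbitrary choices for illustration, not part of the model card:

```python
def generate_agent_action(content: str) -> str:
    """One model call over a fully built prompt string."""
    messages = [{'role': 'user', 'content': content}]
    inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
    outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, eos_token_id=tokenizer.eos_token_id)
    return tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)

conversation_history = []
agent_action = generate_agent_action(
    build_prompt(task_instruction, format_instruction, xlam_format_tools, query, conversation_history))

for _ in range(3):  # arbitrary turn budget for this sketch
    environment_response = get_environment_response(agent_action)
    user_input = ""  # fill in if the user speaks on this turn
    update_conversation_history(conversation_history, agent_action, environment_response, user_input)
    content = build_prompt(task_instruction, format_instruction, xlam_format_tools, query, conversation_history)
    agent_action = generate_agent_action(content)
```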

#### Example multi-turn prompt and output

Prompt:
````json
[BEGIN OF TASK INSTRUCTION]
Based on the previous context and API request history, generate an API request or a response as an AI assistant.
[END OF TASK INSTRUCTION]

[BEGIN OF AVAILABLE TOOLS]
[
    {
        "name": "get_fire_info",
        "description": "Query the latest wildfire information",
        "parameters": {
            "location": {
                "type": "string",
                "description": "Location of the wildfire, for example: 'California'",
                "required": true,
                "format": "free"
            },
            "radius": {
                "type": "number",
                "description": "The radius (in miles) around the location where the wildfire is occurring, for example: 10",
                "required": false,
                "format": "free"
            }
        }
    },
    {
        "name": "get_hurricane_info",
        "description": "Query the latest hurricane information",
        "parameters": {
            "name": {
                "type": "string",
                "description": "Name of the hurricane, for example: 'Irma'",
                "required": true,
                "format": "free"
            }
        }
    },
    {
        "name": "get_earthquake_info",
        "description": "Query the latest earthquake information",
        "parameters": {
            "magnitude": {
                "type": "number",
                "description": "The minimum magnitude of the earthquake that needs to be queried.",
                "required": false,
                "format": "free"
            },
            "location": {
                "type": "string",
                "description": "Location of the earthquake, for example: 'California'",
                "required": false,
                "format": "free"
            }
        }
    }
]
[END OF AVAILABLE TOOLS]

[BEGIN OF FORMAT INSTRUCTION]
Your output should be in the JSON format, which specifies a list of function calls. The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
```{"thought": "the thought process, or an empty string", "tool_calls": [{"name": "api_name1", "arguments": {"argument1": "value1", "argument2": "value2"}}]}```
[END OF FORMAT INSTRUCTION]

[BEGIN OF QUERY]
User: Can you give me the latest information on the wildfires occurring in California?
[END OF QUERY]

[BEGIN OF HISTORY STEPS]
[
    {
        "thought": "Sure, what is the radius (in miles) around the location of the wildfire?",
        "tool_calls": [],
        "step_id": 1,
        "next_observation": "",
        "user_input": "User: Let me think... 50 miles."
    },
    {
        "thought": "",
        "tool_calls": [
            {
                "name": "get_fire_info",
                "arguments": {
                    "location": "California",
                    "radius": 50
                }
            }
        ],
        "step_id": 2,
        "next_observation": [
            {
                "location": "Los Angeles",
                "acres_burned": 1500,
                "status": "contained"
            },
            {
                "location": "San Diego",
                "acres_burned": 12000,
                "status": "active"
            }
        ]
    },
    {
        "thought": "Based on the latest information, there are wildfires in Los Angeles and San Diego. The wildfire in Los Angeles has burned 1,500 acres and is contained, while the wildfire in San Diego has burned 12,000 acres and is still active.",
        "tool_calls": [],
        "step_id": 3,
        "next_observation": "",
        "user_input": "User: Can you tell me about the latest earthquake?"
    }
]
[END OF HISTORY STEPS]
````

Output:
````json
{"thought": "", "tool_calls": [{"name": "get_earthquake_info", "arguments": {"location": "California"}}]}
````


## License
The model is distributed under the CC-BY-NC-4.0 license.

<!-- ## Citation

If you find this repo helpful, please cite our paper:
```bibtex
``` -->
example/multi_turn_xlam.ipynb ADDED
@@ -0,0 +1,459 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "ce4a9ccf-4bd6-43fb-a24d-b6a7da401a96",
   "metadata": {},
   "source": [
    "## Load xLAM model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b1351d81-4502-4b65-b88a-464acd0e80f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "torch.random.manual_seed(0)\n",
    "\n",
    "model_name = \"Salesforce/xLAM-7b-r\"\n",
    "model = AutoModelForCausalLM.from_pretrained(model_name, device_map=\"auto\", torch_dtype=\"auto\", trust_remote_code=True)\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2cdd5bae-da43-4713-9956-360f1f3a9721",
   "metadata": {},
   "source": [
    "## Build the prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "e138e9f6-0543-427c-bce6-b4f14765a040",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "# Please use our provided instruction prompt for best performance\n",
    "task_instruction = \"\"\"\n",
    "Based on the previous context and API request history, generate an API request or a response as an AI assistant.\"\"\".strip()\n",
    "\n",
    "format_instruction = \"\"\"\n",
    "The output should be of the JSON format, which specifies a list of generated function calls. The example format is as follows, please make sure the parameter type is correct. If no function call is needed, please make \n",
    "tool_calls an empty list \"[]\".\n",
    "```\n",
    "{\"thought\": \"the thought process, or an empty string\", \"tool_calls\": [{\"name\": \"api_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}}]}\n",
    "```\n",
    "\"\"\".strip()\n",
    "\n",
    "get_weather_api = {\n",
    "    \"name\": \"get_weather\",\n",
    "    \"description\": \"Get the current weather for a location\",\n",
    "    \"parameters\": {\n",
    "        \"type\": \"object\",\n",
    "        \"properties\": {\n",
    "            \"location\": {\n",
    "                \"type\": \"string\",\n",
    "                \"description\": \"The city and state, e.g. San Francisco, New York\"\n",
    "            },\n",
    "            \"unit\": {\n",
    "                \"type\": \"string\",\n",
    "                \"enum\": [\"celsius\", \"fahrenheit\"],\n",
    "                \"description\": \"The unit of temperature to return\"\n",
    "            }\n",
    "        },\n",
    "        \"required\": [\"location\"]\n",
    "    }\n",
    "}\n",
    "\n",
    "search_api = {\n",
    "    \"name\": \"search\",\n",
    "    \"description\": \"Search for information on the internet\",\n",
    "    \"parameters\": {\n",
    "        \"type\": \"object\",\n",
    "        \"properties\": {\n",
    "            \"query\": {\n",
    "                \"type\": \"string\",\n",
    "                \"description\": \"The search query, e.g. 'latest news on AI'\"\n",
    "            }\n",
    "        },\n",
    "        \"required\": [\"query\"]\n",
    "    }\n",
    "}\n",
    "\n",
    "openai_format_tools = [get_weather_api, search_api]\n",
    "\n",
    "# Define the input query and available tools\n",
    "query = \"What's the weather like in New York in fahrenheit?\"\n",
    "\n",
    "# Helper function to convert OpenAI-format tools to our more concise xLAM format\n",
    "def convert_to_xlam_tool(tools):\n",
    "    \"\"\"Recursively convert OpenAI-format tool definitions to the xLAM format.\"\"\"\n",
    "    if isinstance(tools, dict):\n",
    "        return {\n",
    "            \"name\": tools[\"name\"],\n",
    "            \"description\": tools[\"description\"],\n",
    "            \"parameters\": {k: v for k, v in tools[\"parameters\"].get(\"properties\", {}).items()}\n",
    "        }\n",
    "    elif isinstance(tools, list):\n",
    "        return [convert_to_xlam_tool(tool) for tool in tools]\n",
    "    else:\n",
    "        return tools\n",
    "\n",
    "def build_conversation_history_prompt(conversation_history: list):\n",
    "    parsed_history = []\n",
    "    for step_data in conversation_history:\n",
    "        parsed_history.append({\n",
    "            \"step_id\": step_data[\"step_id\"],\n",
    "            \"thought\": step_data[\"thought\"],\n",
    "            \"tool_calls\": step_data[\"tool_calls\"],\n",
    "            \"next_observation\": step_data[\"next_observation\"],\n",
    "            \"user_input\": step_data[\"user_input\"]\n",
    "        })\n",
    "\n",
    "    history_string = json.dumps(parsed_history)\n",
    "    return f\"\\n[BEGIN OF HISTORY STEPS]\\n{history_string}\\n[END OF HISTORY STEPS]\\n\"\n",
    "\n",
    "\n",
    "# Helper function to build the input prompt for our model\n",
    "def build_prompt(task_instruction: str, format_instruction: str, tools: list, query: str, conversation_history: list):\n",
    "    prompt = f\"[BEGIN OF TASK INSTRUCTION]\\n{task_instruction}\\n[END OF TASK INSTRUCTION]\\n\\n\"\n",
    "    prompt += f\"[BEGIN OF AVAILABLE TOOLS]\\n{json.dumps(tools)}\\n[END OF AVAILABLE TOOLS]\\n\\n\"\n",
    "    prompt += f\"[BEGIN OF FORMAT INSTRUCTION]\\n{format_instruction}\\n[END OF FORMAT INSTRUCTION]\\n\\n\"\n",
    "    prompt += f\"[BEGIN OF QUERY]\\n{query}\\n[END OF QUERY]\\n\\n\"\n",
    "\n",
    "    if len(conversation_history) > 0:\n",
    "        prompt += build_conversation_history_prompt(conversation_history)\n",
    "    return prompt\n",
    "\n",
    "\n",
    "# Build the input and start the inference\n",
    "xlam_format_tools = convert_to_xlam_tool(openai_format_tools)\n",
    "\n",
    "conversation_history = []\n",
    "content = build_prompt(task_instruction, format_instruction, xlam_format_tools, query, conversation_history)\n",
    "\n",
    "messages = [\n",
    "    {'role': 'user', 'content': content}\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "ff7bccd5-fa04-4fbe-92b3-13f58914da4d",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[BEGIN OF TASK INSTRUCTION]\n",
      "Based on the previous context and API request history, generate an API request or a response as an AI assistant.\n",
      "[END OF TASK INSTRUCTION]\n",
      "\n",
      "[BEGIN OF AVAILABLE TOOLS]\n",
      "[{\"name\": \"get_weather\", \"description\": \"Get the current weather for a location\", \"parameters\": {\"location\": {\"type\": \"string\", \"description\": \"The city and state, e.g. San Francisco, New York\"}, \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"], \"description\": \"The unit of temperature to return\"}}}, {\"name\": \"search\", \"description\": \"Search for information on the internet\", \"parameters\": {\"query\": {\"type\": \"string\", \"description\": \"The search query, e.g. 'latest news on AI'\"}}}]\n",
      "[END OF AVAILABLE TOOLS]\n",
      "\n",
      "[BEGIN OF FORMAT INSTRUCTION]\n",
      "The output should be of the JSON format, which specifies a list of generated function calls. The example format is as follows, please make sure the parameter type is correct. If no function call is needed, please make \n",
      "tool_calls an empty list \"[]\".\n",
      "```\n",
      "{\"thought\": \"the thought process, or an empty string\", \"tool_calls\": [{\"name\": \"api_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}}]}\n",
      "```\n",
      "[END OF FORMAT INSTRUCTION]\n",
      "\n",
      "[BEGIN OF QUERY]\n",
      "What's the weather like in New York in fahrenheit?\n",
      "[END OF QUERY]\n",
      "\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(content)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a5fb0006-9f5d-4d79-a8cd-819bad627441",
   "metadata": {},
   "source": [
    "## Get the model output (agent_action)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cbe56588-c786-4913-9062-373a22a92e08",
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n",
    "\n",
    "# tokenizer.eos_token_id is the id of the <|EOT|> token\n",
    "outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)\n",
    "agent_action = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b20ed2ae-86f6-489b-ad54-fe7ea911667b",
   "metadata": {},
   "source": [
    "For demo purposes, we use an example agent_action"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ab20c084-44fa-403d-92a5-1b8ced72e9be",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "agent_action = \"\"\"{\"thought\": \"\", \"tool_calls\": [{\"name\": \"get_weather\", \"arguments\": {\"location\": \"New York\"}}]}\n",
    "\"\"\".strip()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1cd4d8e4-ee6b-499e-b75f-a48df7848a60",
   "metadata": {},
   "source": [
    "### Add follow-up question"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "825649ba-2691-43a2-b3d8-7baf8b66d46e",
   "metadata": {},
   "outputs": [],
   "source": [
    "def parse_agent_action(agent_action: str):\n",
    "    \"\"\"\n",
    "    Given an agent's action, parse it to add to conversation history\n",
    "    \"\"\"\n",
    "    try:\n",
    "        parsed_agent_action_json = json.loads(agent_action)\n",
    "    except json.JSONDecodeError:\n",
    "        return \"\", []\n",
    "\n",
    "    thought = parsed_agent_action_json.get(\"thought\", \"\")\n",
    "    tool_calls = parsed_agent_action_json.get(\"tool_calls\", [])\n",
    "\n",
    "    return thought, tool_calls\n",
    "\n",
    "def update_conversation_history(conversation_history: list, agent_action: str, environment_response: str, user_input: str):\n",
    "    \"\"\"\n",
    "    Update the conversation history list based on the new agent_action, environment_response, and/or user_input\n",
    "    \"\"\"\n",
    "    thought, tool_calls = parse_agent_action(agent_action)\n",
    "    new_step_data = {\n",
    "        \"step_id\": len(conversation_history) + 1,\n",
    "        \"thought\": thought,\n",
    "        \"tool_calls\": tool_calls,\n",
    "        \"next_observation\": environment_response,\n",
    "        \"user_input\": user_input,\n",
    "    }\n",
    "\n",
    "    conversation_history.append(new_step_data)\n",
    "\n",
    "def get_environment_response(agent_action: str):\n",
    "    \"\"\"\n",
    "    Get the environment response for the agent_action\n",
    "    \"\"\"\n",
    "    # TODO: add custom implementation here\n",
    "    error_message, response_message = \"\", \"Sunny, 81 degrees\"\n",
    "    return {\"error\": error_message, \"response\": response_message}\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "051e6aff-c21b-4dcb-9eb8-c34154d90c39",
   "metadata": {},
   "source": [
    "1. **Get the next state after the agent's response:**\n",
    "   The next 2 lines are examples of getting the environment response and user_input.\n",
    "   Depending on the particular usage, we can have either one or both of those."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "649a8e9d-9757-408c-9214-0590556c2db4",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "environment_response = get_environment_response(agent_action)\n",
    "user_input = \"Now, search on the Internet for cute puppies\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9c9c9418-1c54-4381-81d1-7f3834037739",
   "metadata": {},
   "source": [
    "2. After we get environment_response and/or user_input, we add them to our conversation history"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "bcfe89f3-8237-41bf-b92c-7c7568366042",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'step_id': 1,\n",
       "  'thought': '',\n",
       "  'tool_calls': [{'name': 'get_weather',\n",
       "    'arguments': {'location': 'New York'}}],\n",
       "  'next_observation': {'error': '', 'response': 'Sunny, 81 degrees'},\n",
       "  'user_input': 'Now, search on the Internet for cute puppies'}]"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "update_conversation_history(conversation_history, agent_action, environment_response, user_input)\n",
    "conversation_history"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "23ba97c6-2356-49e8-a07b-0e664b7f505c",
   "metadata": {},
   "source": [
    "3. We can now build the prompt with the updated history and prepare the inputs for the LLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ed204b3a-3be5-431b-b355-facaf31309d2",
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "content = build_prompt(task_instruction, format_instruction, xlam_format_tools, query, conversation_history)\n",
    "messages = [\n",
    "    {'role': 'user', 'content': content}\n",
    "]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "8af843aa-6a47-4938-a455-567ea0cccce3",
   "metadata": {
    "tags": []
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[BEGIN OF TASK INSTRUCTION]\n",
      "Based on the previous context and API request history, generate an API request or a response as an AI assistant.\n",
      "[END OF TASK INSTRUCTION]\n",
      "\n",
      "[BEGIN OF AVAILABLE TOOLS]\n",
      "[{\"name\": \"get_weather\", \"description\": \"Get the current weather for a location\", \"parameters\": {\"location\": {\"type\": \"string\", \"description\": \"The city and state, e.g. San Francisco, New York\"}, \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"], \"description\": \"The unit of temperature to return\"}}}, {\"name\": \"search\", \"description\": \"Search for information on the internet\", \"parameters\": {\"query\": {\"type\": \"string\", \"description\": \"The search query, e.g. 'latest news on AI'\"}}}]\n",
      "[END OF AVAILABLE TOOLS]\n",
      "\n",
      "[BEGIN OF FORMAT INSTRUCTION]\n",
      "The output should be of the JSON format, which specifies a list of generated function calls. The example format is as follows, please make sure the parameter type is correct. If no function call is needed, please make \n",
      "tool_calls an empty list \"[]\".\n",
      "```\n",
      "{\"thought\": \"the thought process, or an empty string\", \"tool_calls\": [{\"name\": \"api_name1\", \"arguments\": {\"argument1\": \"value1\", \"argument2\": \"value2\"}}]}\n",
      "```\n",
      "[END OF FORMAT INSTRUCTION]\n",
      "\n",
      "[BEGIN OF QUERY]\n",
      "What's the weather like in New York in fahrenheit?\n",
      "[END OF QUERY]\n",
      "\n",
      "\n",
      "[BEGIN OF HISTORY STEPS]\n",
      "[{\"step_id\": 1, \"thought\": \"\", \"tool_calls\": [{\"name\": \"get_weather\", \"arguments\": {\"location\": \"New York\"}}], \"next_observation\": {\"error\": \"\", \"response\": \"Sunny, 81 degrees\"}, \"user_input\": \"Now, search on the Internet for cute puppies\"}]\n",
      "[END OF HISTORY STEPS]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "print(content)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71f76a10-a152-49d7-aa6f-3060cc49b935",
   "metadata": {},
   "source": [
    "## Get the model output for follow-up question"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "30af06fd-4aa7-4550-af39-3a77b5951882",
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors=\"pt\").to(model.device)\n",
    "\n",
    "# Generate the outputs & decode\n",
    "# tokenizer.eos_token_id is the id of the <|EOT|> token\n",
    "outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)\n",
    "agent_action = tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel) (Local)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}