Zack Zhiyuan Li commited on
Commit
cead849
1 Parent(s): e6b1fe1
Files changed (1) hide show
  1. README.md +127 -33
README.md CHANGED
@@ -37,49 +37,129 @@ Octopus-V1, a series of advanced open-source language models with parameters ran
37
 
38
  You can run the model on a GPU using the following code.
39
  ```python
40
- from gemma.modeling_gemma import GemmaForCausalLM
41
- from transformers import AutoTokenizer
42
  import torch
43
- import time
44
-
45
- def inference(input_text):
46
- start_time = time.time()
47
- input_ids = tokenizer(input_text, return_tensors="pt").to(model.device)
48
- input_length = input_ids["input_ids"].shape[1]
49
- outputs = model.generate(
50
- input_ids=input_ids["input_ids"],
51
- max_length=1024,
52
- do_sample=False)
53
- generated_sequence = outputs[:, input_length:].tolist()
54
- res = tokenizer.decode(generated_sequence[0])
55
- end_time = time.time()
56
- return {"output": res, "latency": end_time - start_time}
57
-
58
- model_id = "NexaAIDev/android_API_10k_data"
59
- tokenizer = AutoTokenizer.from_pretrained(model_id)
60
- model = GemmaForCausalLM.from_pretrained(
61
- model_id, torch_dtype=torch.bfloat16, device_map="auto"
62
- )
63
-
64
- input_text = "Take a selfie for me with front camera"
65
- nexa_query = f"Below is the query from the users, please call the correct function and generate the parameters to call the function.\n\nQuery: {input_text} \n\nResponse:"
66
- start_time = time.time()
67
- print("nexa model result:\n", inference(nexa_query))
68
- print("latency:", time.time() - start_time," s")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  ```
70
 
71
  ## Evaluation
72
- <p align="center" width="100%">
73
- <a><img src="latency_plot.jpg" alt="ondevice" style="width: 80%; min-width: 300px; display: block; margin: auto; margin-bottom: 20px;"></a>
74
- <a><img src="accuracy_plot.jpg" alt="ondevice" style="width: 80%; min-width: 300px; display: block; margin: auto;"></a>
75
- </p>
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  ## License
78
  This model was trained on commercially viable data and is under the [Nexa AI community disclaimer](https://www.nexa4ai.com/disclaimer).
79
 
80
 
81
  ## References
82
- We thank the Google Gemma team for their amazing models!
83
  ```
84
  @misc{gemma-2023-open-models,
85
  author = {{Gemma Team, Google DeepMind}},
@@ -87,6 +167,20 @@ We thank the Google Gemma team for their amazing models!
87
  url = {https://goo.gle/GemmaReport},
88
  year = {2023},
89
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  ```
91
 
92
  ## Citation
 
37
 
38
  You can run the model on a GPU using the following code.
39
  ```python
 
 
40
  import torch
41
+ from transformers import AutoTokenizer, AutoModelForCausalLM
42
+ import torch.nn.functional as F
43
+
44
+ prompt = """You are an assistant, and you need to call find appropriate functions according to the query of the users. Firstly, find the relevant functions, then get the function arguments by understanding the user's query. The following functions are available for you to fetch further data to answer user questions:
45
+
46
+ Function:
47
+ def basketapi_league_seasons(tournamentId):
48
+ '''
49
+ Get access to historical and current seasons for a specific basketball league using the tournament ID.
50
+ Args:
51
+ tournamentId (number): The argument tournamentId is a number that represents the identifier of a tournament in the context of the function.
52
+ '''
53
+
54
+ def os_sports_goal_distributions(unique_tournament_id,season_id,team_id):
55
+ '''
56
+ Get goal distributions by team, tournament ID, and season ID for in-depth sports performance analysis.
57
+ Args:
58
+ unique_tournament_id (number): The unique_tournament_id argument is a number representing the unique identifier for a tournament.
59
+ season_id (number): The argument season_id is a number that represents the identifier of the season for the search query string.
60
+ team_id (number): The team_id argument represents the teams identification number.
61
+ '''
62
+
63
+ def transfermarket_get_table(id,seasonID,domain,homeAway):
64
+ '''
65
+ Get tables by competition and season from transfermarket platform for comprehensive and detailed competition and season-related data.
66
+ Args:
67
+ id (string): The function argument "id" is a string representing an identifier.
68
+ seasonID (string): The seasonID argument is a string that represents the identifier for a specific season.
69
+ domain (string): The domain argument is a string that represents a search query.
70
+ homeAway (string): The homeAway argument is a string that represents the home or away status for a sports event.
71
+ '''
72
+
73
+ def no_relevant_function():
74
+ '''
75
+ Call this when no other provided function can be called to answer the user query.
76
+ '''
77
+
78
+ def soccersapi_stage_id(t,id):
79
+ '''
80
+ Get stage ID for a soccer match or event, access specific details like schedules, teams, and relevant data.
81
+ Args:
82
+ t (string): The argument "t" of type string represents the search query string.
83
+ id (number): This function argument is an identifier represented by a number, typically used to uniquely reference a specific entity within the system.
84
+ '''
85
+
86
+ Request the complete season data for a recently established basketball league using the tournament ID 309, aiming to analyze its inaugural seasons.
87
+ Response:
88
+ """
89
+
90
+ class NexaGenerator:
91
+ def __init__(self, model_id: AutoModelForCausalLM, tokenizer_id: AutoTokenizer):
92
+ self.model = AutoModelForCausalLM.from_pretrained(
93
+ model_id, torch_dtype=torch.bfloat16, device_map="auto"
94
+ )
95
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
96
+ self.eos_token_id = self.tokenizer.eos_token_id
97
+ self.token2id = self.tokenizer.get_vocab()
98
+
99
+ def deterministic_generate_next_token(
100
+ self,
101
+ input_ids: torch.Tensor, # shape: (1, seq_len), no support for batch yet
102
+ add_conditional_mask: bool = False,
103
+ usable_token_ids: torch.tensor = None, # element is token id
104
+ ) -> torch.tensor:
105
+ if add_conditional_mask:
106
+ assert usable_token_ids is not None, "usable_token_ids is required"
107
+ next_logits = self.model(input_ids)["logits"][:, -1:]
108
+ if add_conditional_mask:
109
+ mask = torch.full_like(next_logits, float("-inf"))
110
+ mask.scatter_(
111
+ 2,
112
+ usable_token_ids.unsqueeze(0).unsqueeze(0),
113
+ next_logits.gather(2, usable_token_ids.unsqueeze(0).unsqueeze(0)),
114
+ )
115
+ next_token_id = torch.argmax(mask, dim=-1)
116
+ else:
117
+ next_token_id = torch.argmax(next_logits, dim=-1)
118
+ return next_token_id
119
+
120
+ nexa_generator = NexaGenerator(model_id="NexaAIDev/Octopus-v1", tokenizer_id="NexaAIDev/Octopus-v1")
121
+
122
+ def get_response(prompt):
123
+ input_ids = nexa_generator.tokenizer(prompt, return_tensors="pt")["input_ids"].to("cuda")
124
+ for _ in range(200):
125
+ next_token_id = nexa_generator.deterministic_generate_next_token(
126
+ input_ids=input_ids,
127
+ add_conditional_mask=False,
128
+ usable_token_ids=None,
129
+ )
130
+ input_ids = torch.cat([input_ids, next_token_id], dim=-1)
131
+ if next_token_id[0].item() == nexa_generator.eos_token_id:
132
+ break
133
+ generated_text = nexa_generator.tokenizer.batch_decode(input_ids)
134
+ return generated_text[0]
135
+
136
+ print(get_response(prompt))
137
  ```
138
 
139
  ## Evaluation
140
+ <div style="width: 100%;">
141
+ <div style="text-align: left;">
142
+ Comparison of accuracy between the GPT-3.5 and GPT-4 models, alongside our pretrained models in the Octopus series.
143
+ </div>
144
+ <div style="text-align: center;">
145
+ <img src="accuracy_no_mask.jpg" alt="Accuracy without Conditional Mask" style="width: 80%; min-width: 300px; margin-bottom: 20px;">
146
+ </div>
147
+
148
+ <div style="text-align: left;">
149
+ Comparison of accuracy following the introduction of a conditional mask in the Octopus series models.
150
+ </div>
151
+ <div style="text-align: center;">
152
+ <img src="accuracy_with_mask.jpg" alt="Accuracy with Conditional Mask" style="width: 80%; min-width: 300px;">
153
+ </div>
154
+ </div>
155
+
156
 
157
  ## License
158
  This model was trained on commercially viable data and is under the [Nexa AI community disclaimer](https://www.nexa4ai.com/disclaimer).
159
 
160
 
161
  ## References
162
+ We thank the Meta llama2 team, Google Gemma team, Stability AI's Stable Code team for their amazing models!
163
  ```
164
  @misc{gemma-2023-open-models,
165
  author = {{Gemma Team, Google DeepMind}},
 
167
  url = {https://goo.gle/GemmaReport},
168
  year = {2023},
169
  }
170
+
171
+ @article{touvron2023llama,
172
+ title={Llama 2: Open foundation and fine-tuned chat models},
173
+ author={Touvron, Hugo and Martin, Louis and Stone, Kevin and Albert, Peter and Almahairi, Amjad and Babaei, Yasmine and Bashlykov, Nikolay and Batra, Soumya and Bhargava, Prajjwal and Bhosale, Shruti and others},
174
+ journal={arXiv preprint arXiv:2307.09288},
175
+ year={2023}
176
+ }
177
+
178
+ @misc{stable-code-3b,
179
+ author = {Pinnaparaju, Nikhil and Adithyan, Reshinth and Phung, Duy and Tow, Jonathan and Baicoianu, James and Cooper, Nathan},
180
+ title = {Stable Code 3B},
181
+ url = {https://huggingface.co/stabilityai/stable-code-3b},
182
+ year = {2023}
183
+ }
184
  ```
185
 
186
  ## Citation