---
license: apache-2.0
language:
- en
tags:
- gemma
- function calling
- on-device language model
- android
- conversational
inference: false
---

# Octopus V1: On-device language model for function calling of software APIs
<p align="center">
- <a href="https://www.nexa4ai.com/" target="_blank">Nexa AI Product</a>
- <a href="https://nexaai.github.io/octopus" target="_blank">ArXiv</a>
</p>

<p align="center" width="100%">
<a><img src="Octopus-logo.jpeg" alt="nexa-octopus" style="width: 40%; min-width: 300px; display: block; margin: auto;"></a>
</p>

## Introducing Octopus-V1
Octopus-V1, a series of open-source language models ranging from 2B to 7B parameters, represents Nexa AI's breakthrough in AI-driven software API interaction. Fine-tuned on a specialized dataset built from 30k+ RapidAPI Hub APIs, Octopus-V1 excels at understanding API structures and syntax. A conditional masking technique keeps the generated API calls precise and format-compliant without compromising inference speed. On a new software-API-usage benchmark introduced alongside Octopus-V1, the models outperform GPT-4, a leap forward in automating software development and API integration.

📱 **Supports 30k+ APIs from RapidAPI Hub**: Octopus is trained on an extensive dataset derived from over 30,000 popular APIs on RapidAPI Hub. This broad coverage of diverse software API interactions makes the model useful across a wide range of applications.

🐙 **Accuracy**: Fine-tuning base models with 2B, 3B, and 7B parameters yields Octopus, which surpasses GPT-4 in API call accuracy. Adding a conditional mask further improves its precision, making Octopus highly reliable for software API interactions.

🎯 **Conditional Masking**: A novel conditional masking technique constrains outputs to the desired format and reduces errors. It preserves fast inference speed while substantially increasing the model's accuracy in generating function calls and their parameters.
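To make the masking idea concrete, here is a minimal plain-Python sketch of one masked greedy decoding step. It mirrors what `deterministic_generate_next_token` in the GPU example below does with `scatter_`/`gather` when `add_conditional_mask=True`: every token id outside the allowed set gets a logit of -inf, so argmax can only pick an allowed id. The toy vocabulary, the logits, and the rule "only names of the listed functions are allowed at this step" are illustrative assumptions, not a description of how Octopus actually builds its mask.

```python
import math


def masked_argmax(logits, usable_ids):
    """Greedy token choice restricted to `usable_ids` (an iterable of token ids).

    Every id outside the allowed set is masked to -inf before taking the argmax.
    """
    usable = {i for i in usable_ids if 0 <= i < len(logits)}
    if not usable:
        raise ValueError("the mask must allow at least one token id")
    masked = [score if i in usable else -math.inf for i, score in enumerate(logits)]
    return max(range(len(masked)), key=masked.__getitem__)


# Hypothetical toy vocabulary: one "token" per string, purely for illustration.
vocab = ["Hello", "basketapi_league_seasons", "no_relevant_function", "The"]
logits = [3.1, 2.4, 0.7, 2.9]

# Assumed constraint: at this step the output must be the name of a listed function.
function_names = {"basketapi_league_seasons", "no_relevant_function"}
allowed = {i for i, tok in enumerate(vocab) if tok in function_names}

print(vocab[masked_argmax(logits, range(len(vocab)))])  # unmasked pick: Hello
print(vocab[masked_argmax(logits, allowed)])            # masked pick: basketapi_league_seasons
```

With a real tokenizer a function name usually spans several tokens, so a practical mask would have to be recomputed at every step, for example from the prefix generated so far; that construction is not specified in this card.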

## Example Use Cases
<p align="center" width="100%">
<a><img src="tool-usage-compressed.png" alt="ondevice" style="width: 80%; min-width: 300px; display: block; margin: auto;"></a>
</p>

You can run the model on a GPU using the following code. 
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

prompt = """You are an assistant, and you need to call find appropriate functions according to the query of the users. Firstly, find the relevant functions, then get the function arguments by understanding the user's query. The following functions are available for you to fetch further data to answer user questions: 

Function: 
def basketapi_league_seasons(tournamentId): 
    '''
    Get access to historical and current seasons for a specific basketball league using the tournament ID. 
    Args: 
        tournamentId (number): The argument tournamentId is a number that represents the identifier of a tournament in the context of the function. 
    ''' 

def os_sports_goal_distributions(unique_tournament_id,season_id,team_id): 
    '''
    Get goal distributions by team, tournament ID, and season ID for in-depth sports performance analysis. 
    Args: 
        unique_tournament_id (number): The unique_tournament_id argument is a number representing the unique identifier for a tournament. 
        season_id (number): The argument season_id is a number that represents the identifier of the season for the search query string. 
        team_id (number): The team_id argument represents the teams identification number. 
    '''

def transfermarket_get_table(id,seasonID,domain,homeAway):
    '''
    Get tables by competition and season from transfermarket platform for comprehensive and detailed competition and season-related data. 
    Args: 
        id (string): The function argument "id" is a string representing an identifier. 
        seasonID (string): The seasonID argument is a string that represents the identifier for a specific season. 
        domain (string): The domain argument is a string that represents a search query. 
        homeAway (string): The homeAway argument is a string that represents the home or away status for a sports event. 
    '''

def no_relevant_function(): 
    ''' 
    Call this when no other provided function can be called to answer the user query. 
    '''

def soccersapi_stage_id(t,id):
    '''
    Get stage ID for a soccer match or event, access specific details like schedules, teams, and relevant data.
    Args: 
        t (string): The argument "t" of type string represents the search query string. 
        id (number): This function argument is an identifier represented by a number, typically used to uniquely reference a specific entity within the system. 
    ''' 

Request the complete season data for a recently established basketball league using the tournament ID 309, aiming to analyze its inaugural seasons. 
Response:
"""

class NexaGenerator:
    def __init__(self, model_id: str, tokenizer_id: str):
        # device_map="auto" places the weights on the available GPU(s)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype=torch.bfloat16, device_map="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
        self.eos_token_id = self.tokenizer.eos_token_id
        self.token2id = self.tokenizer.get_vocab()

    @torch.no_grad()  # inference only: do not build an autograd graph
    def deterministic_generate_next_token(
        self,
        input_ids: torch.Tensor,  # shape: (1, seq_len), no support for batch yet
        add_conditional_mask: bool = False,
        usable_token_ids=None,  # optional 1-D LongTensor of allowed token ids
    ) -> torch.Tensor:
        if add_conditional_mask:
            assert usable_token_ids is not None, "usable_token_ids is required"
        # Logits for the last position only: shape (1, 1, vocab_size)
        next_logits = self.model(input_ids)["logits"][:, -1:]
        if add_conditional_mask:
            # Start from -inf everywhere, then copy back the logits of the allowed ids,
            # so argmax can only select a token from usable_token_ids.
            index = usable_token_ids.to(next_logits.device).reshape(1, 1, -1)
            mask = torch.full_like(next_logits, float("-inf"))
            mask.scatter_(2, index, next_logits.gather(2, index))
            next_token_id = torch.argmax(mask, dim=-1)
        else:
            next_token_id = torch.argmax(next_logits, dim=-1)
        return next_token_id  # shape: (1, 1)

nexa_generator = NexaGenerator(model_id="NexaAIDev/Octopus-v1", tokenizer_id="NexaAIDev/Octopus-v1")

def get_response(prompt, max_new_tokens=200):
    # Put the prompt on the same device as the model
    input_ids = nexa_generator.tokenizer(prompt, return_tensors="pt")["input_ids"].to(
        nexa_generator.model.device
    )
    prompt_len = input_ids.shape[-1]
    # Greedy decoding, one token per step (no KV cache: each step re-runs the full sequence)
    for _ in range(max_new_tokens):
        next_token_id = nexa_generator.deterministic_generate_next_token(
            input_ids=input_ids,
            add_conditional_mask=False,
            usable_token_ids=None,
        )
        input_ids = torch.cat([input_ids, next_token_id], dim=-1)
        if next_token_id[0].item() == nexa_generator.eos_token_id:
            break
    # Decode only the newly generated tokens, dropping special tokens such as EOS
    return nexa_generator.tokenizer.decode(input_ids[0, prompt_len:], skip_special_tokens=True)

print(get_response(prompt))
```
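To query the model about your own APIs, you can assemble a prompt with the same layout as `prompt` in the example above: the fixed instruction, a `Function:` block of Python-style signatures with docstrings (including `no_relevant_function`), the user query, and a trailing `Response:` line. The sketch below is an assumption inferred from that single example; the exact template and whitespace used in training are not documented here, so compare its output against the example prompt before relying on it. `build_prompt` and `weather_get_forecast` are hypothetical names used only for illustration.

```python
# Instruction text copied verbatim from the example prompt, including its wording.
INSTRUCTION = (
    "You are an assistant, and you need to call find appropriate functions according to "
    "the query of the users. Firstly, find the relevant functions, then get the function "
    "arguments by understanding the user's query. The following functions are available "
    "for you to fetch further data to answer user questions: "
)

NO_RELEVANT_FUNCTION = """def no_relevant_function():
    '''
    Call this when no other provided function can be called to answer the user query.
    '''"""


def build_prompt(function_defs, query):
    """Hypothetical helper: lay out function definitions and a query like the example prompt."""
    defs = list(function_defs)
    # Assumption: always offer the fallback function, as the example prompt does.
    if NO_RELEVANT_FUNCTION not in defs:
        defs.append(NO_RELEVANT_FUNCTION)
    return INSTRUCTION + "\n\nFunction: \n" + "\n\n".join(defs) + "\n\n" + query + "\nResponse:\n"


# Hypothetical API, for illustration only.
my_function = """def weather_get_forecast(city):
    '''
    Get the weather forecast for a city.
    Args:
        city (string): The name of the city.
    '''"""

p = build_prompt([my_function], "What will the weather be in Paris tomorrow?")
print(p)
```

The resulting string can then be passed to `get_response` from the example above.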

## Evaluation
<div style="width: 100%;">
  <div style="text-align: left;">
    Accuracy comparison of GPT-3.5 and GPT-4 with our fine-tuned Octopus series models, without the conditional mask.
  </div>
  <div style="text-align: center;">
    <img src="accuracy_no_mask.jpg" alt="Accuracy without Conditional Mask" style="width: 80%; min-width: 300px; margin-bottom: 20px;">
  </div>

  <div style="text-align: left;">
    Accuracy of the Octopus series models with the conditional mask applied.
  </div>
  <div style="text-align: center;">
    <img src="accuracy_with_mask.jpg" alt="Accuracy with Conditional Mask" style="width: 80%; min-width: 300px;">
  </div>
</div>


## License
This model was trained on commercially viable data and is under the [Nexa AI community disclaimer](https://www.nexa4ai.com/disclaimer). 


## References
We thank the Meta Llama 2 team, the Google Gemma team, and Stability AI's Stable Code team for their amazing models!
```
@misc{gemma-2023-open-models,
  author = {{Gemma Team, Google DeepMind}},
  title = {Gemma: Open Models Based on Gemini Research and Technology},
  url = {https://goo.gle/GemmaReport},  
  year = {2023},
}

@article{touvron2023llama,
  title={Llama 2: Open foundation and fine-tuned chat models},
  author={Touvron, Hugo and Martin, Louis and Stone, Kevin and Albert, Peter and Almahairi, Amjad and Babaei, Yasmine and Bashlykov, Nikolay and Batra, Soumya and Bhargava, Prajjwal and Bhosale, Shruti and others},
  journal={arXiv preprint arXiv:2307.09288},
  year={2023}
}

@misc{stable-code-3b,
  author = {Pinnaparaju, Nikhil and Adithyan, Reshinth and Phung, Duy and Tow, Jonathan and Baicoianu, James and Cooper, Nathan},
  title = {Stable Code 3B},
  url = {https://huggingface.co/stabilityai/stable-code-3b},
  year = {2023}
}
```

## Citation
```
@misc{TODO}
```

## Contact
Please [contact us]([email protected]) with any issues or comments!