---
license: apache-2.0
datasets:
- gsm8k
- ChilleD/SVAMP
- EleutherAI/asdiv
metrics:
- accuracy
---

# Model Card for etd-codet5p-770m-py

We use Ensemble Thoughts Distillation (ETD) to distill the mathematical reasoning ability of gpt-3.5-turbo into CodeT5+-770m-py.

### Model Description

- **Developed by:** Xunyu Zhu
- **Model type:** encoder-decoder
- **Language(s) (NLP):** Python
- **License:** apache-2.0
- **Finetuned from model:** [Salesforce/codet5p-770m-py](https://huggingface.co/Salesforce/codet5p-770m-py)

## Uses

## Direct Use

This model can be loaded with `AutoModelForSeq2SeqLM` and uses the same tokenizer as the original [Salesforce/codet5p-770m-py](https://huggingface.co/Salesforce/codet5p-770m-py).

Append one of the following prompts to the question to select the reasoning format the model generates:

- "Let’s break down the code step by step" instructs the model to generate a program (Program-of-Thought, PoT).
- "Let's think step by step." instructs the model to generate a rationale (Chain-of-Thought, CoT).
- "System of linear equations: (Do not simplify)" instructs the model to generate a system of equations (Equation-of-Thought, EoT).

The full input has the form `Question: <question>\n<prompt>\n`, as in the examples for each format.

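The three prompts above can be wrapped in a small helper so that the input format stays consistent across reasoning formats. This is a minimal sketch: the names `PROMPTS` and `build_input` are illustrative and not part of the released code, and the `Question: ...` layout is taken from the usage examples in this card.

```python
# Prompt suffixes quoted from this model card.
# PROMPTS and build_input are hypothetical names used for illustration.
PROMPTS = {
    "pot": "Let’s break down the code step by step",
    "cot": "Let's think step by step.",
    "eot": "System of linear equations: (Do not simplify)",
}


def build_input(question: str, mode: str) -> str:
    """Format a question the way the usage examples in this card do."""
    if mode not in PROMPTS:
        raise ValueError(f"unknown mode {mode!r}; expected one of {sorted(PROMPTS)}")
    return f"Question: {question}\n{PROMPTS[mode]}\n"


print(build_input("What is 1 + 1?", "cot"))
```

The returned string can be passed to the tokenizer in place of the hand-written `question` strings.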
### PoT

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

checkpoint = "zhuxunyu/etd-codet5p-770m-py"
device = "cuda"  # use "cpu" if no GPU is available

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(device)

# The trailing prompt selects the PoT (program) output format.
question = "Question: Janet\u2019s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\nLet’s break down the code step by step\n"

inputs = tokenizer(question, max_length=256, padding="max_length", truncation=True, return_tensors="pt").to(model.device)

with torch.no_grad():
    output = model.generate(**inputs, max_length=256)

# generate() returns a batch of sequences; decode the first (and only) one.
generation = tokenizer.decode(output[0], skip_special_tokens=True)
```

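In PoT mode the decoded `generation` is a Python program, so obtaining a numeric answer requires executing it. The sketch below shows one way to do that. It assumes the generated program prints its final answer, which this card does not state; `run_generated_program` and the example program text are hypothetical.

```python
import contextlib
import io


def run_generated_program(program: str) -> str:
    """Execute a generated program and return whatever it printed.

    Assumption: the program prints its final answer. exec() runs
    arbitrary code, so use this only in a sandboxed environment.
    """
    buffer = io.StringIO()
    namespace = {}  # fresh namespace so separate runs do not share state
    with contextlib.redirect_stdout(buffer):
        exec(program, namespace)
    return buffer.getvalue().strip()


# Hypothetical program text standing in for `generation`.
example_program = "eggs_left = 16 - 3 - 4\nprint(eggs_left * 2)"
print(run_generated_program(example_program))
```

If the generated programs store the answer in a variable instead of printing it, the value can be read from `namespace` after `exec` returns.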
### CoT

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

checkpoint = "zhuxunyu/etd-codet5p-770m-py"
device = "cuda"  # use "cpu" if no GPU is available

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(device)

# The trailing prompt selects the CoT (rationale) output format.
question = "Question: Janet\u2019s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\nLet's think step by step.\n"

inputs = tokenizer(question, max_length=256, padding="max_length", truncation=True, return_tensors="pt").to(model.device)

with torch.no_grad():
    output = model.generate(**inputs, max_length=256)

# generate() returns a batch of sequences; decode the first (and only) one.
generation = tokenizer.decode(output[0], skip_special_tokens=True)
```

### EoT

```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

checkpoint = "zhuxunyu/etd-codet5p-770m-py"
device = "cuda"  # use "cpu" if no GPU is available

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(device)

# The trailing prompt selects the EoT (equation system) output format.
question = "Question: Janet\u2019s ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?\nSystem of linear equations: (Do not simplify)\n"

inputs = tokenizer(question, max_length=256, padding="max_length", truncation=True, return_tensors="pt").to(model.device)

with torch.no_grad():
    output = model.generate(**inputs, max_length=256)

# generate() returns a batch of sequences; decode the first (and only) one.
generation = tokenizer.decode(output[0], skip_special_tokens=True)
```

## Training Details

### Training Data

We prompt gpt-3.5-turbo to generate reasoning processes for the questions in the GSM8K training set. Each question is paired with 4 reasoning programs (PoT), 4 reasoning rationales (CoT), and 4 systems of equations (EoT). The questions and their corresponding reasoning processes form the training dataset used to fine-tune the model.

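The dataset construction described above can be sketched as follows. This is an illustration under assumptions, not the authors' preprocessing code: it assumes each training input uses the same `Question: ...` layout and prompt suffix as at inference time, and that the target is the raw reasoning text. `PROMPTS`, `build_training_pairs`, and the toy data are hypothetical.

```python
from typing import Dict, List, Tuple

# Prompt suffixes quoted from this card; pairing them with training
# targets in this way is an assumption.
PROMPTS = {
    "pot": "Let’s break down the code step by step",
    "cot": "Let's think step by step.",
    "eot": "System of linear equations: (Do not simplify)",
}


def build_training_pairs(
    question: str, thoughts: Dict[str, List[str]]
) -> List[Tuple[str, str]]:
    """Turn one question and its sampled reasoning processes into (input, target) pairs."""
    pairs = []
    for mode, samples in thoughts.items():
        source = f"Question: {question}\n{PROMPTS[mode]}\n"
        for sample in samples:
            pairs.append((source, sample))
    return pairs


# Toy placeholders standing in for the 4 programs, 4 rationales,
# and 4 equation systems generated for one question.
toy_thoughts = {
    "pot": ["p1", "p2", "p3", "p4"],
    "cot": ["c1", "c2", "c3", "c4"],
    "eot": ["e1", "e2", "e3", "e4"],
}
pairs = build_training_pairs("What is 2 + 3?", toy_thoughts)
print(len(pairs))  # 3 formats x 4 samples = 12 pairs per question
```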
## Evaluation

### Results

Accuracy (%) on each dataset:

| Method | GSM8K | ASDiv | SVAMP | MultiArith |
| :----------: | :---: | :---: | :---: | :--------: |
| PoT | 50.34 | 55.2 | 51.6 | 88.33 |
| EoT | 48.21 | 52.81 | 55.7 | 70.16 |
| CoT | 25.47 | 29.67 | 23.3 | 46.5 |
| Ensemble_all | 50.56 | 55.34 | 52.3 | 88.83 |

## Citation

**BibTeX:**

```
@misc{zhu2024improving,
      title={Distilling Mathematical Reasoning Capabilities into Small Language Models},
      author={Xunyu Zhu and Jian Li and Yong Liu and Can Ma and Weiping Wang},
      year={2024},
      eprint={2401.11864},
      archivePrefix={arXiv},
      primaryClass={cs.CL}
}
```