---
library_name: peft
base_model: mistralai/Mistral-7B-Instruct-v0.2
license: apache-2.0
language:
- en
datasets:
- chtmp223/suri
---
# Suri-SFT
Suri-SFT is a LoRA adapter for [mistralai/Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2), trained with supervised fine-tuning (SFT) on the Suri dataset. Please see [our paper](TODO) for more details on the method.
## 📒 Model Details
### Model Description
- **Language(s) (NLP):** English
- **License:** Apache-2.0
- **Finetuned from model:** [mistralai/Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2)
### Model Sources
- **Repository:** [GitHub repository](https://github.com/chtmp223/suri) -- contains code to reconstruct the books3 subset.
- **Paper:** TODO
- **Demo:** [Website](https://chtmp223.github.io/suri)
## ⚠️ Getting Started
Use the code in [this repository](https://github.com/chtmp223/suri) for training and inference.
## 💻 Training Details
### Training Data
[chtmp223/suri](https://huggingface.co/datasets/chtmp223/suri)
### Training Procedure
| **Configurations** | **Values** |
|----------------------------------|--------------|
| Hardware (Training and Inference)| 4xA100s |
| Tracking | wandb |
| lora_r | 16 |
| lora_alpha | 16 |
| lora_dropout | 0.05 |
| gradient_accumulation_steps | 1 |
| gradient_checkpointing | True |
| learning_rate | 5.0e-5 |
| lr_scheduler_type | cosine |
| max_length | 15024 |
| max_completion_length | 15000 |
| max_prompt_length | 5000 |
| num_train_epochs | 2 |
| optim | adamw_torch |
| per_device_train_batch_size | 1 |
#### Software
Training code is adapted from [Alignment Handbook](https://github.com/huggingface/alignment-handbook) and [Trl](https://github.com/huggingface/trl).
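
The sketch below shows how the hyperparameters in the table above map onto a `peft`/`trl` training setup. It is a minimal, illustrative example rather than the actual training script (see the [Suri repository](https://github.com/chtmp223/suri) for that); it assumes an older `trl.SFTTrainer` interface that accepts `max_seq_length` directly, and the `dataset_text_field` and output paths are placeholders.

```python
# Illustrative sketch only: maps the table above onto peft.LoraConfig,
# transformers.TrainingArguments, and trl.SFTTrainer.
from datasets import load_dataset
from peft import LoraConfig
from transformers import TrainingArguments
from trl import SFTTrainer

dataset = load_dataset("chtmp223/suri", split="train")

peft_config = LoraConfig(
    r=16,                # lora_r
    lora_alpha=16,       # lora_alpha
    lora_dropout=0.05,   # lora_dropout
    task_type="CAUSAL_LM",
)

training_args = TrainingArguments(
    output_dir="suri-sft",            # placeholder output directory
    learning_rate=5.0e-5,
    lr_scheduler_type="cosine",
    num_train_epochs=2,
    per_device_train_batch_size=1,
    gradient_accumulation_steps=1,
    gradient_checkpointing=True,
    optim="adamw_torch",
    report_to="wandb",
)

trainer = SFTTrainer(
    model="mistralai/Mistral-7B-Instruct-v0.2",
    args=training_args,
    train_dataset=dataset,
    peft_config=peft_config,
    dataset_text_field="text",   # placeholder: depends on dataset preprocessing
    max_seq_length=15024,        # max_length from the table
)
trainer.train()
```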
## 🤗 Inference
```python
import os

import torch
from peft import PeftConfig, PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

os.environ["TOKENIZERS_PARALLELISM"] = "False"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()

model_name = "chtmp223/suri-sft"
base_model_name = "mistralai/Mistral-7B-Instruct-v0.2"

# Load the base model and attach the Suri-SFT LoRA adapter
config = PeftConfig.from_pretrained(model_name)
base_model = AutoModelForCausalLM.from_pretrained(base_model_name).to(device)
model = PeftModel.from_pretrained(base_model, model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(base_model_name)

user_prompt = "Your instruction here."  # replace with your own prompt
prompt = [
    {
        "role": "user",
        "content": user_prompt,
    }
]

# Build the chat-formatted input and generate
input_context = tokenizer.apply_chat_template(
    prompt, add_generation_prompt=True, tokenize=False
)
input_ids = tokenizer.encode(
    input_context, return_tensors="pt", add_special_tokens=False
).to(model.device)
output = model.generate(
    input_ids, max_length=10000, do_sample=True, use_cache=True
).cpu()
print(tokenizer.decode(output[0]))
```
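
Optionally, once the adapter is loaded as above, it can be merged into the base weights for standalone use. This relies on peft's standard `merge_and_unload()` API and is not specific to Suri-SFT; the output directory is a placeholder.

```python
# Merge the LoRA adapter into the base model and save a standalone copy
# (standard peft API; the output directory is a placeholder).
merged_model = model.merge_and_unload()
merged_model.save_pretrained("suri-sft-merged")
tokenizer.save_pretrained("suri-sft-merged")
```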
## 📜 Citation
```
TODO
```
### ⚙️ Framework versions
- PEFT 0.11.1