BramVanroy commited on
Commit
1794d9c
·
verified ·
1 Parent(s): e86ae6c

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +123 -21
README.md CHANGED
@@ -1,50 +1,152 @@
1
  ---
2
  license: cc-by-nc-4.0
3
- base_model: BramVanroy/GEITje-ultra-sft
4
  tags:
5
  - alignment-handbook
6
  - generated_from_trainer
7
  - trl
8
  - dpo
9
- - generated_from_trainer
10
  datasets:
11
  - BramVanroy/ultra_feedback_dutch
12
  model-index:
13
- - name: GEITje-ultra-dpo-5e-7lr-128tbs-0.1b
14
  results: []
 
 
 
15
  ---
16
 
17
- <!-- This model card has been generated automatically according to the information the Trainer had access to. You
18
- should probably proofread and complete it, then remove this comment. -->
 
 
19
 
20
- # GEITje-ultra-dpo-5e-7lr-128tbs-0.1b
 
 
21
 
22
- This model is a fine-tuned version of [BramVanroy/GEITje-ultra-sft](https://huggingface.co/BramVanroy/GEITje-ultra-sft) on the BramVanroy/ultra_feedback_dutch dataset.
23
- It achieves the following results on the evaluation set:
24
- - Loss: 0.0138
25
- - Rewards/chosen: -2.1351
26
- - Rewards/rejected: -13.8922
27
- - Rewards/accuracies: 0.9950
28
- - Rewards/margins: 11.7570
29
- - Logps/rejected: -565.1809
30
- - Logps/chosen: -519.8008
31
- - Logits/rejected: -3.0261
32
- - Logits/chosen: -2.9779
33
 
34
  ## Model description
35
 
36
- More information needed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
  ## Intended uses & limitations
39
 
40
- More information needed
 
 
41
 
42
  ## Training and evaluation data
43
 
44
- More information needed
45
 
 
 
 
46
  ## Training procedure
47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  ### Training hyperparameters
49
 
50
  The following hyperparameters were used during training:
@@ -77,4 +179,4 @@ The following hyperparameters were used during training:
77
  - Transformers 4.36.2
78
  - Pytorch 2.1.2+cu121
79
  - Datasets 2.14.6
80
- - Tokenizers 0.15.0
 
1
  ---
2
  license: cc-by-nc-4.0
3
+ base_model: BramVanroy/GEITje-7B-ultra-sft
4
  tags:
5
  - alignment-handbook
6
  - generated_from_trainer
7
  - trl
8
  - dpo
9
+ - geitje
10
  datasets:
11
  - BramVanroy/ultra_feedback_dutch
12
  model-index:
13
+ - name: BramVanroy/GEITje-7B-ultra
14
  results: []
15
+ language:
16
+ - nl
17
+ pipeline_tag: conversational
18
  ---
19
 
20
+ <img src="https://huggingface.co/BramVanroy/GEITje-ultra/resolve/main/geitje-ultra-banner.png" alt="GEITje Ultra banner" width="800" style="margin-left:'auto' margin-right:'auto' display:'block'"/>
21
+
22
+
23
+ # GEITje 7B ultra
24
 
25
+ **A conversational model, aligned through AI feedback.**
26
+
27
+ This model is a fine-tuned version of [BramVanroy/GEITje-ultra-sft](https://huggingface.co/BramVanroy/GEITje-ultra-sft) on a synthetic DPO dataset of around 56M tokens that was generated with gpt-4-turbo and [Rijgersberg/GEITje-7B-chat](https://huggingface.co/Rijgersberg/GEITje-7B-chat) for Dutch.
28
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  ## Model description
31
 
32
+ This is a Dutch instruction/chat model ultimately based on Mistral and aligned with AI feedback via DPO. It is a DPO continuation of the SFT trained [BramVanroy/GEITje-ultra-sft](https://huggingface.co/BramVanroy/GEITje-ultra-sft), which in turn is based on [Rijgersberg/GEITje-7B](https://huggingface.co/Rijgersberg/GEITje-7B), which in turn is based on Mistral 7B and further pretrained on Dutch data. In (rather naive) [benchmarks](https://huggingface.co/spaces/BramVanroy/open_dutch_llm_leaderboard) it outperforms all the original GEITje models on average and ties with the powerful Zephyr model by Hugging Face. However, note that these benchmarks should be taken with a massive grain of salt (see the disclaimer below the benchmarks on that page).
33
+
34
+
35
+ ## Usage
36
+
37
+ One-off:
38
+
39
+ ```python
40
+ from transformers import pipeline, Conversation
41
+
42
+ # load_in_8bit: lower precision but saves a lot of GPU memory
43
+ # device_map=auto: loads the model across multiple GPUs
44
+ chatbot = pipeline("conversational", model="BramVanroy/GEITje-7B-ultra", model_kwargs={"load_in_8bit": True}, device_map="auto")
45
+
46
+ start_messages = [
47
+ {"role": "system", "content": "Je bent een grappige chatbot die Bert heet. Je maakt vaak mopjes."},
48
+ {"role": "user", "content": "Hallo, ik ben Bram. Ik wil vanavond graag een film kijken. Heb je enkele suggesties?"}
49
+ ]
50
+ conversation = Conversation(start_messages)
51
+ conversation = chatbot(conversation)
52
+ response = conversation.messages[-1]["content"]
53
+ print(response)
54
+ ```
55
+
56
+ Interactive conversation:
57
+
58
+ ```python
59
+ from transformers import pipeline, Conversation
60
+
61
+ # load_in_8bit: lower precision but saves a lot of memory
62
+ # device_map=auto: loads the model across multiple GPUs
63
+ # attn_implementation: uses flash attention, if your device supports it - otherwise remove it
64
+ chatbot = pipeline("conversational", model="BramVanroy/GEITje-7B-ultra", model_kwargs={"load_in_8bit": True, "attn_implementation": "flash_attention_2"}, device_map="auto")
65
+
66
+ while (system_message := input("System message ('q' to quit): ")) != "q":
67
+ start_messages = [
68
+ {"role": "system", "content": system_message},
69
+ ]
70
+ conversation = Conversation(start_messages)
71
+ while (user_input := input("User ('r' to reset): ")) != "r":
72
+ conversation.add_user_input(user_input)
73
+ conversation = chatbot(conversation)
74
+ response = conversation.messages[-1]["content"]
75
+ print("Assistant:", response)
76
+
77
+ ```
78
 
79
  ## Intended uses & limitations
80
 
81
+ Although the model has been aligned with gpt-4-turbo output, which has strong content filters, the model could still generate wrong, misleading, and potentially even offensive content. Use at your own risk.
82
+
83
+ Because the model was trained on synthetic data created with OpenAI/Azure services, this model cannot be used for commercial purposes.
84
 
85
  ## Training and evaluation data
86
 
87
+ The training data consists of a synthetic dataset based on [UltraFeedback binarized](https://huggingface.co/datasets/HuggingFaceH4/ultrafeedback_binarized) created with gpt-4-turbo and geitje-chat. A given prompt, translated from the original dataset, is given to the two models who then generated an answer. Then, gpt-4-turbo is always selected as the best answer which DPO will optimise for. While this is not completely fair, I did not have the budget to actually have gpt-4 rate both replies. Furthermore, while an impressive model, GEITje chat still seems behind gpt-4-turbo in the testing that I have done.
88
 
89
+ In total the dataset consists of 56,137,090 tokens (combination of prompt + rejected + chosen) and a test set of 6,178,969 tokens (11.00%).
90
+
91
+
92
  ## Training procedure
93
 
94
+ The great [alignment handbook](https://github.com/huggingface/alignment-handbook/) was used for training, with a custom slurm script for compatibility with our cluster. It was trained in full, without LoRA or other adapters.
95
+
96
+ The model was trained in bfloat16 with flash attention 2 on two nodes of four A100 80GB each for around 11 hours. I thank the [Flemish Super Computer](https://www.vscentrum.be/compute) for their compute.
97
+
98
+ For conversational usage, the model relies on the Zephyr chat template, which is compatible with system messages. A small portion of the data of *-sft contained system messages, so it is assumed the model can handle system messages at least a little bit.
99
+
100
+ In earlier iterations I found that using the alignment handbook's defaults (beta=0.01) led to poor results (hallucinations of random tokens). After investigating, it seems that such a low beta does not work well for this dataset as it gives the model too much room to deviate from its initial base model. After a [hyperparameter search](https://huggingface.co/posts/BramVanroy/492522322273746) and manual analysis of the resulting metrics, I selected the current model as the best one, with a beta of 0.1.
101
+
102
+ Recipe used with the handbook:
103
+
104
+ ```yaml
105
+ # Model arguments
106
+ model_name_or_path: BramVanroy/GEITje-7B-ultra-sft
107
+ model_revision: main
108
+ torch_dtype: bfloat16
109
+ use_flash_attention_2: true
110
+
111
+ # Data training arguments
112
+ # For definitions, see: src/h4/training/config.py
113
+ dataset_mixer:
114
+ BramVanroy/ultra_feedback_dutch: 1.0
115
+ dataset_splits:
116
+ - train_prefs
117
+ - test_prefs
118
+ preprocessing_num_workers: 8
119
+
120
+ # DPOTrainer arguments
121
+ bf16: true
122
+ beta: 0.1
123
+ do_eval: true
124
+ evaluation_strategy: steps
125
+ eval_steps: 100
126
+ gradient_accumulation_steps: 4
127
+ gradient_checkpointing: true
128
+ gradient_checkpointing_kwargs:
129
+ use_reentrant: False
130
+ hub_model_id: BramVanroy/GEITje-ultra
131
+ learning_rate: 5.0e-7
132
+ log_level: info
133
+ logging_steps: 10
134
+ lr_scheduler_type: cosine
135
+ max_length: 2048
136
+ max_prompt_length: 1536
137
+ num_train_epochs: 1
138
+ optim: adamw_torch
139
+ output_dir: data/GEITje-ultra
140
+ per_device_train_batch_size: 4
141
+ per_device_eval_batch_size: 4
142
+ push_to_hub: true
143
+ save_strategy: "steps"
144
+ save_steps: 100
145
+ save_total_limit: 3
146
+ seed: 42
147
+ warmup_ratio: 0.1
148
+ ```
149
+
150
  ### Training hyperparameters
151
 
152
  The following hyperparameters were used during training:
 
179
  - Transformers 4.36.2
180
  - Pytorch 2.1.2+cu121
181
  - Datasets 2.14.6
182
+ - Tokenizers 0.15.0