PEFT
Safetensors
French
eltorio committed on
Commit
f4f7ed0
1 Parent(s): 38e0f32

Upload autoeval-training-llama-3-2-3b.ipynb

Browse files
Files changed (1) hide show
  1. autoeval-training-llama-3-2-3b.ipynb +635 -0
autoeval-training-llama-3-2-3b.ipynb ADDED
@@ -0,0 +1,635 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# Autoeval"
8
+ ]
9
+ },
10
+ {
11
+ "cell_type": "code",
12
+ "execution_count": null,
13
+ "metadata": {
14
+ "execution": {
15
+ "iopub.execute_input": "2024-12-02T11:56:29.397635Z",
16
+ "iopub.status.busy": "2024-12-02T11:56:29.397111Z",
17
+ "iopub.status.idle": "2024-12-02T11:56:29.411850Z",
18
+ "shell.execute_reply": "2024-12-02T11:56:29.410508Z",
19
+ "shell.execute_reply.started": "2024-12-02T11:56:29.397590Z"
20
+ },
21
+ "trusted": true
22
+ },
23
+ "outputs": [],
24
+ "source": [
25
+ "import os\n",
26
+ "source_model = \"unsloth/Llama-3.2-3B-Instruct\"\n",
27
+ "destination_model = \"Llama-3.2-3B-appreciation\"\n",
28
+ "dataset_url = \"eltorio/appreciation\"\n",
29
+ "epoch = 5\n",
30
+ "push_to_hub = True if os.path.exists('/kaggle/working') else False\n",
31
+ "output_directory = '/kaggle/working' if os.path.exists('/kaggle/working') else './'\n",
32
+ "kaggle_model = f\"eltorio/{destination_model.lower()}/transformers/default\""
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "markdown",
37
+ "metadata": {},
38
+ "source": [
39
+ "## Install the required libraries"
40
+ ]
41
+ },
42
+ {
43
+ "cell_type": "code",
44
+ "execution_count": null,
45
+ "metadata": {
46
+ "trusted": true
47
+ },
48
+ "outputs": [],
49
+ "source": [
50
+ "%%capture\n",
51
+ "!pip install -U \"safetensors>=0.4.5\"\n",
52
+ "!pip install -U tensorflow\n",
53
+ "!pip install -U \"https://github.com/bitsandbytes-foundation/bitsandbytes/releases/download/continuous-release_main/bitsandbytes-0.44.2.dev0-py3-none-manylinux_2_24_x86_64.whl\"\n",
54
+ "!pip install -U git+https://github.com/huggingface/transformers.git\n",
55
+ "!pip install huggingface_hub[cli] accelerate datasets peft\n",
56
+ "!pip install pip3-autoremove\n",
57
+ "!pip-autoremove torch torchvision torchaudio -y\n",
58
+ "!pip install torch torchvision torchaudio xformers --index-url https://download.pytorch.org/whl/cu121\n",
59
+ "!pip install unsloth\n",
60
+ "!pip uninstall unsloth -y && pip install --upgrade --no-cache-dir --no-deps git+https://github.com/unslothai/unsloth.git\n",
61
+ "!pip install tf-keras"
62
+ ]
63
+ },
64
+ {
65
+ "cell_type": "markdown",
66
+ "metadata": {},
67
+ "source": [
68
+ "### Log in to Kaggle"
69
+ ]
70
+ },
71
+ {
72
+ "cell_type": "code",
73
+ "execution_count": null,
74
+ "metadata": {
75
+ "trusted": true
76
+ },
77
+ "outputs": [],
78
+ "source": [
79
+ "import os\n",
80
+ "import json\n",
81
+ "if not os.path.exists('/kaggle/.kaggle/kaggle.json'):\n",
82
+ " try:\n",
83
+ " from kaggle_secrets import UserSecretsClient\n",
84
+ " user_secrets = UserSecretsClient()\n",
85
+ " KAGGLE_JSON = user_secrets.get_secret(\"KAGGLE_JSON\")\n",
86
+ " except:\n",
87
+ " KAGGLE_JSON = os.getenv(\"KAGGLE_JSON\")\n",
88
+ "\n",
89
+ " kaggle_dir = os.path.expanduser(\"~/.kaggle\")\n",
90
+ " kaggle_file = os.path.join(kaggle_dir, \"kaggle.json\")\n",
91
+ "\n",
92
+ " os.makedirs(kaggle_dir, exist_ok=True)\n",
93
+ "\n",
94
+ " with open(kaggle_file, 'w') as file:\n",
95
+ " json.dump(KAGGLE_JSON, file)"
96
+ ]
97
+ },
98
+ {
99
+ "cell_type": "markdown",
100
+ "metadata": {},
101
+ "source": [
102
+ "### Log in to Weights & Biases (W&B)"
103
+ ]
104
+ },
105
+ {
106
+ "cell_type": "code",
107
+ "execution_count": null,
108
+ "metadata": {},
109
+ "outputs": [],
110
+ "source": [
111
+ "import wandb\n",
112
+ "try:\n",
113
+ " from kaggle_secrets import UserSecretsClient\n",
114
+ " user_secrets = UserSecretsClient()\n",
115
+ " WANDB_API_KEY = user_secrets.get_secret(\"WANDB_API_KEY\")\n",
116
+ " os.environ[\"WANDB_API_KEY\"] = WANDB_API_KEY\n",
117
+ "except:\n",
118
+ " if os.getenv(\"WANDB_API_KEY\") is None:\n",
119
+ " os.environ[\"WANDB_API_KEY\"] = input(\"Enter your W&B API key: \")\n",
120
+ "\n",
121
+ "if not wandb.login():\n",
122
+ " raise Exception(\"Can't login to W&B\")\n",
123
+ "else:\n",
124
+ " print(\"Logged in to W&B\")\n",
125
+ " os.environ[\"WANDB_PROJECT\"]=destination_model"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "markdown",
130
+ "metadata": {},
131
+ "source": [
132
+ "### Log in to the Hugging Face Hub"
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "code",
137
+ "execution_count": null,
138
+ "metadata": {
139
+ "trusted": true
140
+ },
141
+ "outputs": [],
142
+ "source": [
143
+ "from huggingface_hub import login\n",
144
+ "import os\n",
145
+ "\n",
146
+ "try:\n",
147
+ " from kaggle_secrets import UserSecretsClient\n",
148
+ " user_secrets = UserSecretsClient()\n",
149
+ " HF_TOKEN = user_secrets.get_secret(\"HF_TOKEN\")\n",
150
+ " os.environ[\"HF_TOKEN\"] = HF_TOKEN\n",
151
+ "except:\n",
152
+ " if not os.getenv(\"HF_TOKEN\"):\n",
153
+ " raise ValueError(\"You need to set the HF_TOKEN environment variable.\")\n",
154
+ " HF_TOKEN = os.getenv(\"HF_TOKEN\")\n",
155
+ "\n",
156
+ "print(f\"Login with {HF_TOKEN}\")\n",
157
+ "login(\n",
158
+ " token=HF_TOKEN,\n",
159
+ " add_to_git_credential=False\n",
160
+ ")"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "markdown",
165
+ "metadata": {},
166
+ "source": [
167
+ "## Training parameters"
168
+ ]
169
+ },
170
+ {
171
+ "cell_type": "code",
172
+ "execution_count": null,
173
+ "metadata": {
174
+ "trusted": true
175
+ },
176
+ "outputs": [],
177
+ "source": [
178
+ "from unsloth import FastLanguageModel\n",
179
+ "import torch\n",
180
+ "\n",
181
+ "max_seq_length = 2048 # Choose any! We auto support RoPE Scaling internally!\n",
182
+ "dtype = None # None for auto detection. Float16 for Tesla T4, V100, Bfloat16 for Ampere+\n",
183
+ "load_in_4bit = True # Use 4bit quantization to reduce memory usage. Can be False.\n"
184
+ ]
185
+ },
186
+ {
187
+ "cell_type": "markdown",
188
+ "metadata": {},
189
+ "source": [
190
+ "## Load the source model"
191
+ ]
192
+ },
193
+ {
194
+ "cell_type": "code",
195
+ "execution_count": null,
196
+ "metadata": {
197
+ "trusted": true
198
+ },
199
+ "outputs": [],
200
+ "source": [
201
+ "model, tokenizer = FastLanguageModel.from_pretrained(\n",
202
+ " model_name = source_model, # or choose \"unsloth/Llama-3.2-1B-Instruct\"\n",
203
+ " max_seq_length = max_seq_length,\n",
204
+ " dtype = dtype,\n",
205
+ " load_in_4bit = load_in_4bit,\n",
206
+ " token = HF_TOKEN,\n",
207
+ ")"
208
+ ]
209
+ },
210
+ {
211
+ "cell_type": "markdown",
212
+ "metadata": {},
213
+ "source": [
214
+ "## Add the Peft model"
215
+ ]
216
+ },
217
+ {
218
+ "cell_type": "code",
219
+ "execution_count": null,
220
+ "metadata": {
221
+ "trusted": true
222
+ },
223
+ "outputs": [],
224
+ "source": [
225
+ "model = FastLanguageModel.get_peft_model(\n",
226
+ " model,\n",
227
+ " r = 16, # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128\n",
228
+ " target_modules = [\"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n",
229
+ " \"gate_proj\", \"up_proj\", \"down_proj\",],\n",
230
+ " lora_alpha = 16,\n",
231
+ " lora_dropout = 0, # Supports any, but = 0 is optimized\n",
232
+ " bias = \"none\", # Supports any, but = \"none\" is optimized\n",
233
+ " # [NEW] \"unsloth\" uses 30% less VRAM, fits 2x larger batch sizes!\n",
234
+ " use_gradient_checkpointing = \"unsloth\", # True or \"unsloth\" for very long context\n",
235
+ " random_state = 3407,\n",
236
+ " use_rslora = False, # We support rank stabilized LoRA\n",
237
+ " loftq_config = None, # And LoftQ\n",
238
+ ")"
239
+ ]
240
+ },
241
+ {
242
+ "cell_type": "markdown",
243
+ "metadata": {},
244
+ "source": [
245
+ "### Read the data"
246
+ ]
247
+ },
248
+ {
249
+ "cell_type": "code",
250
+ "execution_count": null,
251
+ "metadata": {
252
+ "execution": {
253
+ "iopub.execute_input": "2024-12-02T11:56:34.316028Z",
254
+ "iopub.status.busy": "2024-12-02T11:56:34.315647Z",
255
+ "iopub.status.idle": "2024-12-02T11:56:36.257132Z",
256
+ "shell.execute_reply": "2024-12-02T11:56:36.255969Z",
257
+ "shell.execute_reply.started": "2024-12-02T11:56:34.315997Z"
258
+ },
259
+ "trusted": true
260
+ },
261
+ "outputs": [],
262
+ "source": [
263
+ "from datasets import load_dataset\n",
264
+ "dataset = load_dataset(dataset_url)\n",
265
+ "dataset['train']"
266
+ ]
267
+ },
268
+ {
269
+ "cell_type": "markdown",
270
+ "metadata": {},
271
+ "source": [
272
+ "### Create the messages from the data\n",
273
+ "\n",
274
+ "The data is a CSV file with the following columns:\n",
275
+ "\n",
276
+ "```csv\n",
277
+ "\n",
278
+ "Id,redoublant,matière,trimestre,note 1er trimestre,note 2ème trimestre,note 3ème trimestre,comportement 0-10,participation 0-10,travail 0-10,commentaire\n",
279
+ "\n",
280
+ "0,0,,1,\"Mauvais trimestre, manque de travail\",5.0,,,5.0,5.0,5.0,X a beaucoup de difficultés dues à des lacunes mais aussi à un manque de travail qui ne permet pas de les combler. Il faut s'y mettre au prochain trimestre.\n",
281
+ "\n",
282
+ "```\n",
283
+ "\n",
284
+ "We need to convert each row into Hugging Face's standard multi-turn conversation format (a list of `role`/`content` messages)."
285
+ ]
286
+ },
287
+ {
288
+ "cell_type": "code",
289
+ "execution_count": null,
290
+ "metadata": {
291
+ "execution": {
292
+ "iopub.execute_input": "2024-12-02T11:56:45.923298Z",
293
+ "iopub.status.busy": "2024-12-02T11:56:45.922896Z",
294
+ "iopub.status.idle": "2024-12-02T11:56:45.933706Z",
295
+ "shell.execute_reply": "2024-12-02T11:56:45.932503Z",
296
+ "shell.execute_reply.started": "2024-12-02T11:56:45.923263Z"
297
+ },
298
+ "trusted": true
299
+ },
300
+ "outputs": [],
301
+ "source": [
302
+ "def create_training_turn(row):\n",
303
+ " trimestre = row['trimestre']\n",
304
+ " redoublant = 'redoublant ' if row['redoublant'] == 1 else ''\n",
305
+ " moyenne_1 = row['note 1er trimestre'] if not isinstance(row['note 1er trimestre'],float|int) else 'N/A'\n",
306
+ " moyenne_2 = row['note 2ème trimestre'] if not isinstance(row['note 2ème trimestre'],float|int) else 'N/A'\n",
307
+ " moyenne_3 = row['note 3ème trimestre'] if not isinstance(row['note 3ème trimestre'],float|int) else 'N/A'\n",
308
+ " comportement = row['comportement 0-10']\n",
309
+ " participation = row['participation 0-10']\n",
310
+ " travail = row['travail 0-10']\n",
311
+ " system_prompt = \"Vous êtes une IA assistant les enseignants d'histoire-géographie en rédigeant à leur place une appréciation personnalisée pour leur élève en fonction de ses performances. Votre appréciation doit être en français formel et impersonnel. Votre appréciation doit être bienveillante, constructive, et aider l'élève à comprendre ses points forts et les axes d'amélioration. Votre appréciation doit comporter de 8 à 250 caractères. Votre appréciation ne doit jamais comporter les valeurs des notes. \"\n",
312
+ "\n",
313
+ " if trimestre == 1:\n",
314
+ " trimestre_full = \"premier trimestre\"\n",
315
+ " user_input = f\"Veuillez rédiger une appréciation en moins de 250 caractères pour le {trimestre_full} pour cet élève {redoublant}qui a eu {moyenne_1} de moyenne, j'ai évalué son comportement à {comportement}/10, sa participation à {participation}/10 et son travail à {travail}/10. Les notes ne doivent pas apparaître dans l'appréciation.\"\n",
316
+ " elif trimestre == 2:\n",
317
+ " trimestre_full = \"deuxième trimestre\"\n",
318
+ " user_input = f\"Veuillez rédiger une appréciation en moins de 250 caractères pour le {trimestre_full} pour cet élève {redoublant}qui a eu {moyenne_2} de moyenne ce trimestre et {moyenne_1} au premier trimestre, j'ai évalué son comportement à {comportement}/10, sa participation à {participation}/10 et son travail à {travail}/10. Les notes ne doivent pas apparaître dans l'appréciation.\"\n",
319
+ " elif trimestre == 3:\n",
320
+ " trimestre_full = \"troisième trimestre\"\n",
321
+ " user_input = f\"Veuillez rédiger une appréciation en moins de 250 caractères pour le {trimestre_full} pour cet élève {redoublant}qui a eu {moyenne_3} de moyenne ce trimestre, {moyenne_2} au deuxième trimestre et {moyenne_1} au premier trimestre, j'ai évalué son comportement à {comportement}/10, sa participation à {participation}/10 et son travail à {travail}/10. Les notes ne doivent pas apparaître dans l'appréciation.\"\n",
322
+ "\n",
323
+ " assistant_response = row['commentaire']\n",
324
+ "\n",
325
+ " return {\"conversation\":[\n",
326
+ " {\"role\": \"system\", \"content\":system_prompt},\n",
327
+ " {\"role\": \"user\", \"content\":user_input},\n",
328
+ " {\"role\": \"assistant\", \"content\":assistant_response}\n",
329
+ " ]}\n"
330
+ ]
331
+ },
332
+ {
333
+ "cell_type": "markdown",
334
+ "metadata": {},
335
+ "source": [
336
+ "### Check the function"
337
+ ]
338
+ },
339
+ {
340
+ "cell_type": "code",
341
+ "execution_count": null,
342
+ "metadata": {
343
+ "execution": {
344
+ "iopub.execute_input": "2024-12-02T11:56:50.058458Z",
345
+ "iopub.status.busy": "2024-12-02T11:56:50.058002Z",
346
+ "iopub.status.idle": "2024-12-02T11:56:50.066899Z",
347
+ "shell.execute_reply": "2024-12-02T11:56:50.065730Z",
348
+ "shell.execute_reply.started": "2024-12-02T11:56:50.058406Z"
349
+ },
350
+ "trusted": true
351
+ },
352
+ "outputs": [],
353
+ "source": [
354
+ "test_row = dataset['train'][68]\n",
355
+ "create_training_turn(test_row)"
356
+ ]
357
+ },
358
+ {
359
+ "cell_type": "markdown",
360
+ "metadata": {},
361
+ "source": [
362
+ "### Create the dataset"
363
+ ]
364
+ },
365
+ {
366
+ "cell_type": "code",
367
+ "execution_count": null,
368
+ "metadata": {
369
+ "execution": {
370
+ "iopub.execute_input": "2024-12-02T11:56:58.639949Z",
371
+ "iopub.status.busy": "2024-12-02T11:56:58.639529Z",
372
+ "iopub.status.idle": "2024-12-02T11:56:59.178999Z",
373
+ "shell.execute_reply": "2024-12-02T11:56:59.177678Z",
374
+ "shell.execute_reply.started": "2024-12-02T11:56:58.639912Z"
375
+ },
376
+ "trusted": true
377
+ },
378
+ "outputs": [],
379
+ "source": [
380
+ "multi_turn_dataset = dataset.map(create_training_turn)\n",
381
+ "multi_turn_dataset['train'][68]"
382
+ ]
383
+ },
384
+ {
385
+ "cell_type": "markdown",
386
+ "metadata": {},
387
+ "source": [
388
+ "## Apply the chat template to the data"
389
+ ]
390
+ },
391
+ {
392
+ "cell_type": "code",
393
+ "execution_count": null,
394
+ "metadata": {
395
+ "trusted": true
396
+ },
397
+ "outputs": [],
398
+ "source": [
399
+ "from unsloth.chat_templates import get_chat_template\n",
400
+ "\n",
401
+ "tokenizer = get_chat_template(\n",
402
+ " tokenizer,\n",
403
+ " chat_template = \"llama-3.1\",\n",
404
+ ")\n",
405
+ "\n",
406
+ "def formatting_prompts_func(messages):\n",
407
+ " convos = messages[\"conversation\"]\n",
408
+ " texts = [tokenizer.apply_chat_template(convo, tokenize = False, add_generation_prompt = False) for convo in convos]\n",
409
+ " return { \"text\" : texts, }\n",
410
+ "pass\n",
411
+ "\n",
412
+ "multi_turn_dataset = multi_turn_dataset.map(\n",
413
+ " formatting_prompts_func,\n",
414
+ " batched=True,\n",
415
+ ")"
416
+ ]
417
+ },
418
+ {
419
+ "cell_type": "markdown",
420
+ "metadata": {},
421
+ "source": [
422
+ "### Check the formatted text"
423
+ ]
424
+ },
425
+ {
426
+ "cell_type": "code",
427
+ "execution_count": null,
428
+ "metadata": {
429
+ "execution": {
430
+ "iopub.execute_input": "2024-12-02T11:57:11.739989Z",
431
+ "iopub.status.busy": "2024-12-02T11:57:11.739580Z",
432
+ "iopub.status.idle": "2024-12-02T11:57:12.535408Z",
433
+ "shell.execute_reply": "2024-12-02T11:57:12.533818Z",
434
+ "shell.execute_reply.started": "2024-12-02T11:57:11.739953Z"
435
+ },
436
+ "trusted": true
437
+ },
438
+ "outputs": [],
439
+ "source": [
440
+ "multi_turn_dataset[\"train\"][\"text\"][278]"
441
+ ]
442
+ },
443
+ {
444
+ "cell_type": "markdown",
445
+ "metadata": {},
446
+ "source": [
447
+ "### Parameters for training"
448
+ ]
449
+ },
450
+ {
451
+ "cell_type": "code",
452
+ "execution_count": null,
453
+ "metadata": {
454
+ "trusted": true
455
+ },
456
+ "outputs": [],
457
+ "source": [
458
+ "from trl import SFTTrainer\n",
459
+ "from transformers import TrainingArguments, DataCollatorForSeq2Seq\n",
460
+ "from unsloth import is_bfloat16_supported\n",
461
+ "\n",
462
+ "trainer = SFTTrainer(\n",
463
+ " model = model,\n",
464
+ " tokenizer = tokenizer,\n",
465
+ " train_dataset = multi_turn_dataset[\"train\"],\n",
466
+ " eval_dataset=multi_turn_dataset[\"validation\"],\n",
467
+ " dataset_text_field = \"text\",\n",
468
+ "\n",
469
+ " max_seq_length = max_seq_length,\n",
470
+ " data_collator = DataCollatorForSeq2Seq(tokenizer = tokenizer),\n",
471
+ " dataset_num_proc = 2,\n",
472
+ " packing = False, # Can make training 5x faster for short sequences.\n",
473
+ " args = TrainingArguments(\n",
474
+ " per_device_train_batch_size = 2,\n",
475
+ " gradient_accumulation_steps = 4,\n",
476
+ " warmup_steps = 5,\n",
477
+ " num_train_epochs = epoch, # Set this for 1 full training run.\n",
478
+ " eval_strategy=\"epoch\",\n",
479
+ " save_strategy=\"epoch\",\n",
480
+ " logging_strategy=\"epoch\",\n",
481
+ " # max_steps = 60,\n",
482
+ " learning_rate = 2e-4,\n",
483
+ " fp16 = not is_bfloat16_supported(),\n",
484
+ " bf16 = is_bfloat16_supported(),\n",
485
+ " logging_steps = 1,\n",
486
+ " optim = \"adamw_8bit\",\n",
487
+ " weight_decay = 0.01,\n",
488
+ " lr_scheduler_type = \"linear\",\n",
489
+ " seed = 3407,\n",
490
+ " output_dir = output_directory,\n",
491
+ " report_to = \"wandb\", # Use this for WandB etc\n",
492
+ " push_to_hub = push_to_hub,\n",
493
+ " hub_model_id = destination_model\n",
494
+ " ),\n",
495
+ "\n",
496
+ ")"
497
+ ]
498
+ },
499
+ {
500
+ "cell_type": "code",
501
+ "execution_count": null,
502
+ "metadata": {
503
+ "trusted": true
504
+ },
505
+ "outputs": [],
506
+ "source": [
507
+ "from unsloth.chat_templates import train_on_responses_only\n",
508
+ "\n",
509
+ "trainer = train_on_responses_only(\n",
510
+ " trainer,\n",
511
+ " instruction_part = \"<|start_header_id|>user<|end_header_id|>\\n\\n\",\n",
512
+ " response_part = \"<|start_header_id|>assistant<|end_header_id|>\\n\\n\",\n",
513
+ ")"
514
+ ]
515
+ },
516
+ {
517
+ "cell_type": "code",
518
+ "execution_count": null,
519
+ "metadata": {
520
+ "trusted": true
521
+ },
522
+ "outputs": [],
523
+ "source": [
524
+ "tokenizer.decode(trainer.train_dataset[5][\"input_ids\"])"
525
+ ]
526
+ },
527
+ {
528
+ "cell_type": "code",
529
+ "execution_count": null,
530
+ "metadata": {
531
+ "trusted": true
532
+ },
533
+ "outputs": [],
534
+ "source": [
535
+ "space = tokenizer(\" \", add_special_tokens = False).input_ids[0]\n",
536
+ "tokenizer.decode([space if x == -100 else x for x in trainer.train_dataset[5][\"labels\"]])"
537
+ ]
538
+ },
539
+ {
540
+ "cell_type": "markdown",
541
+ "metadata": {},
542
+ "source": [
543
+ "### Train the model"
544
+ ]
545
+ },
546
+ {
547
+ "cell_type": "code",
548
+ "execution_count": null,
549
+ "metadata": {
550
+ "trusted": true
551
+ },
552
+ "outputs": [],
553
+ "source": [
554
+ "trainer_stats = trainer.train()"
555
+ ]
556
+ },
557
+ {
558
+ "cell_type": "markdown",
559
+ "metadata": {},
560
+ "source": [
561
+ "## Publish to Kaggle"
562
+ ]
563
+ },
564
+ {
565
+ "cell_type": "code",
566
+ "execution_count": null,
567
+ "metadata": {
568
+ "trusted": true
569
+ },
570
+ "outputs": [],
571
+ "source": [
572
+ "import kagglehub\n",
573
+ "import os\n",
574
+ "import re\n",
575
+ "\n",
576
+ "def get_latest_checkpoint(directory):\n",
577
+ " # Liste tous les répertoires dans le répertoire donné\n",
578
+ " subdirs = [d for d in os.listdir(directory) if os.path.isdir(os.path.join(directory, d))]\n",
579
+ " # Filtre les répertoires qui correspondent au format \"checkpoint_xxx\"\n",
580
+ " checkpoint_dirs = [d for d in subdirs if re.match(r'checkpoint-\\d+', d)]\n",
581
+ " print(checkpoint_dirs)\n",
582
+ " # Extrait les valeurs numériques et trouve la plus élevée\n",
583
+ " max_checkpoint = max(checkpoint_dirs, key=lambda x: int(x.split('-')[1]))\n",
584
+ " print(max_checkpoint)\n",
585
+ " return os.path.join(directory, max_checkpoint)\n",
586
+ "\n",
587
+ "\n",
588
+ "latest_checkpoint = get_latest_checkpoint(output_directory)\n",
589
+ "print(f'The newest model is : {latest_checkpoint}')\n",
590
+ "\n",
591
+ "kagglehub.login()\n",
592
+ "kagglehub.model_upload(\n",
593
+ " handle= kaggle_model,\n",
594
+ " local_model_dir = latest_checkpoint\n",
595
+ ")\n"
596
+ ]
597
+ }
598
+ ],
599
+ "metadata": {
600
+ "kaggle": {
601
+ "accelerator": "none",
602
+ "dataSources": [
603
+ {
604
+ "datasetId": 6161747,
605
+ "sourceId": 10010677,
606
+ "sourceType": "datasetVersion"
607
+ }
608
+ ],
609
+ "dockerImageVersionId": 30787,
610
+ "isGpuEnabled": false,
611
+ "isInternetEnabled": true,
612
+ "language": "python",
613
+ "sourceType": "notebook"
614
+ },
615
+ "kernelspec": {
616
+ "display_name": "Python 3 (ipykernel)",
617
+ "language": "python",
618
+ "name": "python3"
619
+ },
620
+ "language_info": {
621
+ "codemirror_mode": {
622
+ "name": "ipython",
623
+ "version": 3
624
+ },
625
+ "file_extension": ".py",
626
+ "mimetype": "text/x-python",
627
+ "name": "python",
628
+ "nbconvert_exporter": "python",
629
+ "pygments_lexer": "ipython3",
630
+ "version": "3.12.7"
631
+ }
632
+ },
633
+ "nbformat": 4,
634
+ "nbformat_minor": 4
635
+ }