{ "cells": [ { "cell_type": "markdown", "id": "5f239612-620e-4430-8685-9fdc6b179b41", "metadata": {}, "source": [ "# Training PEFT models with new tokens being added to the embedding layers and tokenizer\n", "\n", "In this example, we will learn how to train a LoRA model when new tokens are added to the tokenizer and the model.\n", "This is a common use case in scenarios such as the following:\n", "1. Instruction finetuning, where new tokens such as `<|user|>`, `<|assistant|>` and `<|system|>` are added to properly format the conversations.\n", "2. Finetuning on a specific language, where language-specific tokens are added, e.g., Korean tokens added to the vocabulary when finetuning an LLM on Korean datasets.\n", "3. Instruction finetuning to return outputs in a certain format that enables agent behaviour, using new tokens such as `<|FUNCTIONS|>`, `<|BROWSE|>`, `<|TEXT2IMAGE|>`, `<|ASR|>`, `<|TTS|>`, `<|GENERATECODE|>`, `<|RAG|>`.\n", "\n", "In such cases, add the embedding modules to the LoRA `target_modules`. PEFT then saves the embedding layers, including the rows for the newly added tokens, along with the adapter weights. This matters because the adapter weights were trained on top of that specific initialization of the new tokens' embeddings."
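, "\n", "\n", "The point about saving the embedding weights can be illustrated with a toy, pure-Python sketch of LoRA applied to an embedding lookup. It is not PEFT's implementation. The helper `lora_embedding_lookup`, the matrix sizes and all values are hypothetical, and the sketch assumes the standard LoRA update, where the adapted embedding of token `t` is `W[t] + scaling * (B @ A[:, t])` with `scaling = lora_alpha / r`.

```python
# Toy LoRA-on-embedding lookup in pure Python (hypothetical sizes and values).
# W: base embedding matrix, one row per token id (vocab_size x dim).
# A: LoRA down-projection (r x vocab_size); B: LoRA up-projection (dim x r).


def lora_embedding_lookup(W, A, B, token_id, scaling):
    # Column `token_id` of A is the rank-r code of this token.
    a_col = [A[k][token_id] for k in range(len(A))]
    # Map the code back to the embedding dimension: delta = B @ a_col.
    delta = [sum(B[d][k] * a_col[k] for k in range(len(a_col))) for d in range(len(B))]
    # Adapted embedding = frozen base row + scaled low-rank update.
    return [w + scaling * dv for w, dv in zip(W[token_id], delta)]


# Vocabulary of 3 tokens, embedding dim 2, LoRA rank 1.
# Token id 2 stands in for a newly added token whose base row was freshly initialized.
W = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
A = [[0.0, 0.0, 2.0]]
B = [[1.0], [-1.0]]
lora_alpha, r = 2, 1
scaling = lora_alpha / r

print(lora_embedding_lookup(W, A, B, 2, scaling))  # [4.5, -3.5]

# Same adapter, but the new token's base row initialized differently:
W_other_init = [row[:] for row in W]
W_other_init[2] = [0.0, 0.0]
print(lora_embedding_lookup(W_other_init, A, B, 2, scaling))  # [4.0, -4.0]
```

Because the adapter only stores the low-rank update, the same `A` and `B` give a different embedding for the new token when its base row changes. Saving the resized embedding matrix together with the adapter is what keeps the trained model reproducible."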
] }, { "cell_type": "markdown", "id": "b27c55e8-edaa-4059-90bc-d6096d596902", "metadata": {}, "source": [ "Let's import the necessary libraries." ] }, { "cell_type": "code", "execution_count": 1, "id": "6f864c90", "metadata": {}, "outputs": [], "source": [ "import os\n", "\n", "# Expose only GPU 3 to this process; change the index to match your machine.\n", "os.environ[\"CUDA_VISIBLE_DEVICES\"] = \"3\"\n", "# Log the training runs to this Weights & Biases project.\n", "os.environ[\"WANDB_PROJECT\"] = \"PeftExamples\"\n", "import transformers\n", "from peft import (\n", " LoraConfig,\n", " PeftConfig,\n", " PeftModel,\n", " get_peft_model,\n", ")\n", "from peft import prepare_model_for_kbit_training as prepare_model_for_int8_training\n", "from transformers import (\n", " AutoModelForCausalLM,\n", " AutoTokenizer,\n", " HfArgumentParser,\n", " TrainingArguments,\n", " Trainer,\n", " default_data_collator,\n", ")\n", "import torch\n", "from dataclasses import dataclass, field\n", "from typing import Optional\n", "from dataclass_csv import DataclassReader\n", "from torch.utils.data import Dataset, DataLoader\n", "\n", "from enum import Enum" ] }, { "cell_type": "markdown", "id": "74950a3f-bb63-4ce5-9e2b-1b83f92b13a2", "metadata": {}, "source": [ "## Prepare Model and Tokenizer" ] }, { "cell_type": "markdown", "id": "76763f5e-64b2-409b-8845-ae5589f8a4e0", "metadata": {}, "source": [ "Now, we will add 27 new tokens and replace the model's existing pad, bos and eos tokens."
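, "\n", "\n", "Adding tokens appends them to the end of the vocabulary, which is why the embedding matrix has to grow afterwards. The following toy sketch models that bookkeeping in pure Python. The helper `add_special_tokens_toy`, the base vocabulary and the token list are hypothetical. In Transformers the corresponding call is `tokenizer.add_special_tokens`, which returns the number of tokens actually added, and `len(tokenizer)` then reports the enlarged vocabulary size.

```python
# Toy model of how a tokenizer vocabulary grows when special tokens are added.
# The base vocabulary and token strings below are hypothetical.


def add_special_tokens_toy(vocab, tokens):
    # Give each unseen token the next free id; return how many were added.
    added = 0
    for tok in tokens:
        if tok not in vocab:
            vocab[tok] = len(vocab)
            added += 1
    return added


vocab = {'find': 0, 'food': 1, 'place': 2, 'eat': 3}
new_tokens = ['<|pad|>', '<|startoftext|>', '<|user|>', '<|system|>', '<|user|>']
num_added = add_special_tokens_toy(vocab, new_tokens)

print(num_added, len(vocab))  # 4 8 (the repeated '<|user|>' is skipped)
print(vocab['<|pad|>'], vocab['<|system|>'])  # 4 7
```

Only the new ids (4 to 7 here) have no embedding rows yet, so the model's embedding layer must be resized to the new vocabulary size before training."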
] }, { "cell_type": "code", "execution_count": 2, "id": "fd0498ea-547e-418d-bf13-c9abafdd5476", "metadata": {}, "outputs": [], "source": [ "# Special tokens added to the tokenizer; pad_token and bos_token replace the model's existing pad and bos tokens.\n", "class SpecialTokens(str, Enum):\n", " begin_target = \"<|begintarget|>\"\n", " end_target = \"<|endtarget|>\"\n", " begin_context = \"<|begincontext|>\"\n", " end_context = \"<|endcontext|>\"\n", " system = \"<|system|>\"\n", " user = \"<|user|>\"\n", " begin_last_user_utterance = \"<|beginlastuserutterance|>\"\n", " end_last_user_utterance = \"<|endlastuserutterance|>\"\n", " begin_dsts = \"<|begindsts|>\"\n", " end_dsts = \"<|enddsts|>\"\n", " begin_dst = \"<|begindst|>\"\n", " end_dst = \"<|enddst|>\"\n", " begin_belief = \"<|beginbelief|>\"\n", " end_belief = \"<|endbelief|>\"\n", " begin_response = \"<|beginresponse|>\"\n", " end_response = \"<|endresponse|>\"\n", " begin_action = \"<|beginaction|>\"\n", " end_action = \"<|endaction|>\"\n", " begin_user_action = \"<|beginuseraction|>\"\n", " end_user_action = \"<|enduseraction|>\"\n", " sys_actions = \"<|sysactions|>\"\n", " begin_intent = \"<|beginintent|>\"\n", " end_intent = \"<|endintent|>\"\n", " begin_requested_slots = \"<|beginrequestedslots|>\"\n", " end_requested_slots = \"<|endrequestedslots|>\"\n", " pad_token = \"<|pad|>\"\n", " bos_token = \"<|startoftext|>\"\n", "\n", " @classmethod\n", " def list(cls):\n", " # Return the string values of all the special tokens defined above.\n", " return [c.value for c in cls]" ] }, { "cell_type": "markdown", "id": "ae4a4255-5f13-4eef-a024-4f1de0f2173b", "metadata": {}, "source": [ "We will finetune the Mistral-7B model. Let's load the tokenizer and add the special tokens, then load the base model and resize its embedding layers to accommodate the newly added tokens."
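, "\n", "\n", "Resizing keeps the rows of the existing tokens unchanged and appends rows for the new ids. Below is a minimal pure-Python sketch of that step. The helper `resize_embeddings_toy` is hypothetical, and filling the new rows with the mean of the existing rows is an assumed, simple heuristic: the initialization that `resize_token_embeddings` actually uses depends on the Transformers version and the model. For a causal LM, `resize_token_embeddings` also resizes the output projection when it is not tied to the input embeddings.

```python
# Toy version of resizing an embedding matrix after new tokens are added.
# Assumption: new rows are set to the mean of the existing rows (a simple
# heuristic; library defaults may differ).


def resize_embeddings_toy(W, new_num_tokens):
    old_num_tokens, dim = len(W), len(W[0])
    # Existing token ids keep exactly the same embedding rows.
    resized = [row[:] for row in W[:new_num_tokens]]
    if new_num_tokens > old_num_tokens:
        mean_row = [sum(row[d] for row in W) / old_num_tokens for d in range(dim)]
        # One fresh row per newly added token id.
        resized.extend(mean_row[:] for _ in range(new_num_tokens - old_num_tokens))
    return resized


W = [[1.0, 2.0], [3.0, 4.0]]  # 2 existing tokens, dim 2
W_resized = resize_embeddings_toy(W, 4)  # 2 new tokens added
print(W_resized)  # [[1.0, 2.0], [3.0, 4.0], [2.0, 3.0], [2.0, 3.0]]
```

The inference cell later in this notebook calls `resize_token_embeddings(len(tokenizer))`, so that every id the tokenizer can emit has a row."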
] }, { "cell_type": "code", "execution_count": 3, "id": "f0eedef9", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "91c67b6377fc4dd7977bf544de784d51", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Loading checkpoint shards: 0%| | 0/2 [00:00<|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><
|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|pad|><|startoftext|><|begincontext|><|user|> Can you find me place to eat?<|system|> What kind of food would you like to have and where would you like me to search in?<|user|> Food kind of California will be perfect in SF.<|system|> There are 10 restaurants, Al's Place is one of the good restaurant in San Francisco.<|user|> Can you look for any other restaurant?<|system|> Alta Msp is one of the good restaurant in San Francisco.<|beginlastuserutterance|> Can you find me the address?<|endlastuserutterance|><|endcontext|><|begintarget|><|begindsts|><|begindst|><|beginintent|> FindRestaurants<|endintent|><|beginrequestedslots|> Restaurants^street_address<|endrequestedslots|><|beginbelief|> Restaurants^city->SF~San Francisco|Restaurants^cuisine->California<|endbelief|><|enddst|><|enddsts|><|beginuseraction|> REQUEST->Restaurants^street_address~<|enduseraction|><|beginaction|> INFORM->Restaurants^street_address~1275 Minnesota Street<|endaction|><|beginresponse|> The street address of the restaurant is 1275 Minnesota Street.<|endresponse|><|endtarget|><|endtarget|>\"" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenizer.decode(train_dataset[0][\"input_ids\"])" ] }, { "cell_type": "markdown", "id": "239d1c83-196d-471e-9bf7-5f36dafa9894", "metadata": {}, "source": [ "# Train the model" ] }, { "cell_type": "code", "execution_count": 10, "id": "ec80d6ee", "metadata": {}, "outputs": [ { "name": "stderr", 
"output_type": "stream", "text": [ "Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.\n", "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\n", "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33msmangrul\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" ] }, { "data": { "text/html": [ "Tracking run with wandb version 0.16.0" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Run data is saved locally in /raid/sourab/temp/wandb/run-20231128_230934-edod21gq" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "Syncing run ethereal-eon-1 to Weights & Biases (docs)
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View project at https://wandb.ai/smangrul/PeftExamples" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ " View run at https://wandb.ai/smangrul/PeftExamples/runs/edod21gq" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stderr", "output_type": "stream", "text": [ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...\n" ] }, { "data": { "text/html": [ "\n", "
\n", " \n", " \n", " [246/246 05:51, Epoch 2/2]\n", "
\n", "
Step  Training Loss
10    5.189800
20    3.745500
30    2.371500
40    1.630200
50    1.302600
60    0.999400
70    0.704100
80    0.527800
90    0.509700
100   0.382300
110   0.318200
120   0.323500
130   0.263400
140   0.290900
150   0.277400
160   0.232800
170   0.223600
180   0.229600
190   0.233100
200   0.210200
210   0.245800
220   0.197300
230   0.210100
240   0.209800

" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "TrainOutput(global_step=246, training_loss=0.8516577879587809, metrics={'train_runtime': 354.9013, 'train_samples_per_second': 5.556, 'train_steps_per_second': 0.693, 'total_flos': 4.318233532091597e+16, 'train_loss': 0.8516577879587809, 'epoch': 2.0})" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "training_args = TrainingArguments(\n", " output_dir=\"mistral_lora_clm_with_added_tokens\",\n", " num_train_epochs=2,\n", " save_total_limit=5,\n", " per_device_train_batch_size=8,\n", " warmup_steps=10,\n", " weight_decay=0.0001,\n", " dataloader_drop_last=True,\n", " bf16=True,\n", " logging_steps=10,\n", " learning_rate=1e-5,\n", " gradient_checkpointing=True,\n", " gradient_checkpointing_kwargs={\"use_reentrant\": False},\n", " remove_unused_columns=False,\n", " # Use a repo id under your own Hub namespace, or set push_to_hub=False to skip uploading.\n", " hub_model_id=\"smangrul/mistral_lora_clm_with_added_tokens\",\n", " push_to_hub=True,\n", " hub_private_repo=True,\n", ")\n", "trainer = Trainer(\n", " model=model,\n", " args=training_args,\n", " train_dataset=train_dataset,\n", " data_collator=default_data_collator,\n", ")\n", "# model.config.use_cache = False\n", "trainer.train()" ] }, { "cell_type": "markdown", "id": "7bc1cbed-4eb9-4aaa-ab5f-5b91bf432307", "metadata": {}, "source": [ "# Check the model output on a sample from the evaluation dataset" ] }, { "cell_type": "code", "execution_count": 11, "id": "71851793", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "context=\"<|begincontext|><|user|>Can you find me a place to eat please?<|system|>Where at? And what kind of cuisine are you craving?<|user|>Somewhere in SF, and I am really craving Thai food at the moment!<|system|>I found a bunch of restaurants, there's actually 10 that you might like in San Francisco, one of them being Baan Thai House & Wine Bar<|user|>How can I reach them? 
And what's their address?<|system|>You can reach them by phone at 415-379-4505 and visit them at 534 Irving Street<|beginlastuserutterance|>Great, that restaurant sounds good<|endlastuserutterance|><|endcontext|>\" \n", "\n", " target_predicted='<|begintarget|><|begindsts|><|begindst|><|beginintent|> FindRestaurants<|endintent|><|beginbelief|> Restaurants^city->SF~San Francisco|Restaurants^cuisine->Thai|Restaurants^restaurant_name->Baan Thai House & Wine Bar<|endbelief|><|enddst|><|enddsts|><|beginuseraction|> REQUEST->Restaurants^phone_number~|REQUEST->Restaurants^street_address~<|enduseraction|><|beginaction|> INFORM->Restaurants^phone_number~415-379-4505|INFORM->Restaurants^street_address~534 Irving Street<|endaction|><|beginresponse|> Great, the phone number is 415-379-4505 and the address is 534 Irving Street<|endresponse|><|endtarget|>' \n", "\n", " target='<|begintarget|><|begindsts|><|begindst|><|beginintent|>FindRestaurants<|endintent|><|beginbelief|>Restaurants^city->SF~San Francisco|Restaurants^cuisine->Thai|Restaurants^restaurant_name->Baan Thai House & Wine Bar<|endbelief|><|enddst|><|enddsts|><|beginuseraction|>SELECT->Restaurants^~<|enduseraction|><|beginaction|>OFFER_INTENT->Restaurants^intent~ReserveRestaurant<|endaction|><|beginresponse|>Want me to book a table?<|endresponse|><|endtarget|>'\n" ] } ], "source": [ "import random\n", "\n", "# random.randint includes both endpoints, so subtract 1 to stay inside the dataset.\n", "i = random.randint(0, len(dataset[\"test\"]) - 1)\n", "context = dataset[\"test\"][i][\"context\"]\n", "\n", "batch = tokenizer(context, return_tensors=\"pt\")\n", "batch = {k: v.to(\"cuda\") for k, v in batch.items()}\n", "model.eval()\n", "output_tokens = model.generate(\n", " **batch,\n", " max_new_tokens=256,\n", " do_sample=True,\n", " temperature=0.2,\n", " top_p=0.95,\n", " top_k=50,\n", " eos_token_id=tokenizer.eos_token_id,\n", " pad_token_id=tokenizer.pad_token_id,\n", ")\n", "target_predicted = tokenizer.decode(output_tokens[0], skip_special_tokens=False).split(\"<|endcontext|>\")[1]\n", "target = 
dataset[\"test\"][i][\"target\"]\n", "print(f\"{context=} \\n\\n {target_predicted=} \\n\\n {target=}\")" ] }, { "cell_type": "markdown", "id": "f940a660-2f7c-4a3a-b412-3f037aedb890", "metadata": {}, "source": [ "# Save the adapter model" ] }, { "cell_type": "markdown", "id": "7ebe05e9-9b93-42f6-bba8-46b8cc3d100f", "metadata": {}, "source": [ "When LoRA layers are applied to the embedding layers, the corresponding base model embedding layers (resized to include the new tokens) are also saved alongside the adapter weights." ] }, { "cell_type": "code", "execution_count": 12, "id": "3d7459ba-caa8-4f10-aa70-89be4541cbdf", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/raid/sourab/peft/src/peft/utils/save_and_load.py:128: UserWarning: Setting `is_embedding_layer_resized` to `True` as embedding layers found in `target_modules`\n", " warnings.warn(\"Setting `is_embedding_layer_resized` to `True` as embedding layers found in `target_modules`\")\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8d23186832014f209939ab83e79da011", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Upload 3 LFS files: 0%| | 0/3 [00:00<|user|>Can you find me a place to eat please?<|system|>Where at? And what kind of cuisine are you craving?<|user|>Somewhere in SF, and I am really craving Thai food at the moment!<|system|>I found a bunch of restaurants, there's actually 10 that you might like in San Francisco, one of them being Baan Thai House & Wine Bar<|user|>How can I reach them? 
And what's their address?<|system|>You can reach them by phone at 415-379-4505 and visit them at 534 Irving Street<|beginlastuserutterance|>Great, that restaurant sounds good<|endlastuserutterance|><|endcontext|>\" \n", "\n", " target_predicted='<|begintarget|><|begindsts|><|begindst|><|beginintent|> FindRestaurant<|endintent|><|beginbelief|> Restaurants^city->SF~San Francisco|Restaurants^cuisine->Thai|Restaurants^restaurant_name->Baan Thai House & Wine Bar<|endbelief|><|enddst|><|enddsts|><|beginuseraction|> REQUEST->Restaurants^phone_number~|REQUEST->Restaurants^street_address~<|enduseraction|><|beginaction|> INFORM->Restaurants^phone_number~415-379-4505|INFORM->Restaurants^street_address~534 Irving Street<|endaction|><|beginresponse|> The phone number is 415-379-4505 and the address is 534 Irving Street<|endresponse|><|endtarget|>' \n", "\n", " target='<|begintarget|><|begindsts|><|begindst|><|beginintent|>FindRestaurants<|endintent|><|beginbelief|>Restaurants^city->SF~San Francisco|Restaurants^cuisine->Thai|Restaurants^restaurant_name->Baan Thai House & Wine Bar<|endbelief|><|enddst|><|enddsts|><|beginuseraction|>SELECT->Restaurants^~<|enduseraction|><|beginaction|>OFFER_INTENT->Restaurants^intent~ReserveRestaurant<|endaction|><|beginresponse|>Want me to book a table?<|endresponse|><|endtarget|>'\n" ] } ], "source": [ "from peft import PeftModel\n", "\n", "inference_model = AutoModelForCausalLM.from_pretrained(\n", " model_name,\n", " low_cpu_mem_usage=True,\n", " # use_flash_attention_2=True,\n", ")\n", "inference_model.resize_token_embeddings(len(tokenizer))\n", "\n", "inference_model = PeftModel.from_pretrained(inference_model, \"smangrul/mistral_lora_clm_with_added_tokens\")\n", "inference_model.to(\"cuda\")\n", "inference_model.eval()\n", "\n", "output_tokens = inference_model.generate(\n", " **batch,\n", " max_new_tokens=256,\n", " do_sample=True,\n", " temperature=0.2,\n", " top_p=0.95,\n", " top_k=50,\n", " eos_token_id=tokenizer.eos_token_id,\n", " 
pad_token_id=tokenizer.pad_token_id,\n", ")\n", "\n", "target_predicted = tokenizer.decode(output_tokens[0], skip_special_tokens=False).split(\"<|endcontext|>\")[1]\n", "print(f\"{context=} \\n\\n {target_predicted=} \\n\\n {target=}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "fd57f6e8-761f-4e0b-941c-f6973e13b186", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 5 }