diff --git "a/StarCoder2_3b_4bit.ipynb" "b/StarCoder2_3b_4bit.ipynb" new file mode 100644--- /dev/null +++ "b/StarCoder2_3b_4bit.ipynb" @@ -0,0 +1,6829 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Er_P6ziRGCXH", + "outputId": "8ae9ff41-a91c-46e5-8dc1-8a83d336acae" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Cloning into 'starcoder2'...\n", + "remote: Enumerating objects: 44, done.\u001b[K\n", + "remote: Counting objects: 100% (44/44), done.\u001b[K\n", + "remote: Compressing objects: 100% (41/41), done.\u001b[K\n", + "remote: Total 44 (delta 19), reused 9 (delta 2), pack-reused 0\u001b[K\n", + "Receiving objects: 100% (44/44), 21.08 KiB | 3.01 MiB/s, done.\n", + "Resolving deltas: 100% (19/19), done.\n" + ] + } + ], + "source": [ + "!git clone https://github.com/bigcode-project/starcoder2.git" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "mH683TEF08Gs", + "outputId": "4b49252e-05dd-4f2f-e6f1-acb4c5358bdb" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "sample_data starcoder2\n", + "/content/starcoder2\n" + ] + } + ], + "source": [ + "!ls\n", + "%cd starcoder2" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ZnQuOTQCGjZb", + "outputId": "ea9302d5-e401-4980-a350-223bb88c3d79" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting git+https://github.com/huggingface/transformers.git (from -r requirements.txt (line 1))\n", + " Cloning https://github.com/huggingface/transformers.git to /tmp/pip-req-build-r9ck654w\n", + " Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-req-build-r9ck654w\n", + " Resolved https://github.com/huggingface/transformers.git to commit 76a33a10923ccc1074917f6b6a1e719e626b7dc9\n", + " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", + " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", + " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + "Collecting accelerate==0.27.1 (from -r requirements.txt (line 2))\n", + " Downloading accelerate-0.27.1-py3-none-any.whl (279 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m279.7/279.7 kB\u001b[0m \u001b[31m6.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting datasets>=2.16.1 (from -r requirements.txt (line 3))\n", + " Downloading datasets-2.18.0-py3-none-any.whl (510 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m510.5/510.5 kB\u001b[0m \u001b[31m13.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting bitsandbytes==0.41.3 (from -r requirements.txt (line 4))\n", + " Downloading bitsandbytes-0.41.3-py3-none-any.whl (92.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m92.6/92.6 MB\u001b[0m \u001b[31m9.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting peft==0.8.2 (from -r requirements.txt (line 5))\n", + " Downloading peft-0.8.2-py3-none-any.whl (183 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m183.4/183.4 kB\u001b[0m \u001b[31m23.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting trl==0.7.10 (from -r requirements.txt (line 6))\n", + " Downloading trl-0.7.10-py3-none-any.whl (150 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m150.9/150.9 kB\u001b[0m \u001b[31m17.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting wandb==0.16.3 (from -r requirements.txt (line 7))\n", + " Downloading wandb-0.16.3-py3-none-any.whl (2.2 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.2/2.2 MB\u001b[0m \u001b[31m86.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: huggingface_hub==0.20.3 in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 8)) (0.20.3)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from accelerate==0.27.1->-r requirements.txt (line 2)) (1.25.2)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from accelerate==0.27.1->-r requirements.txt (line 2)) (24.0)\n", + "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate==0.27.1->-r requirements.txt (line 2)) (5.9.5)\n", + "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from accelerate==0.27.1->-r requirements.txt (line 2)) (6.0.1)\n", + "Requirement already satisfied: torch>=1.10.0 in /usr/local/lib/python3.10/dist-packages (from accelerate==0.27.1->-r requirements.txt (line 2)) (2.2.1+cu121)\n", + "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from accelerate==0.27.1->-r requirements.txt (line 2)) (0.4.2)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from peft==0.8.2->-r requirements.txt (line 5)) (4.66.2)\n", + "Collecting tyro>=0.5.11 (from trl==0.7.10->-r requirements.txt (line 6))\n", + " Downloading tyro-0.7.3-py3-none-any.whl (79 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m79.8/79.8 kB\u001b[0m \u001b[31m10.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: Click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb==0.16.3->-r requirements.txt (line 7)) (8.1.7)\n", + "Collecting GitPython!=3.1.29,>=1.0.0 (from wandb==0.16.3->-r requirements.txt (line 7))\n", + " Downloading GitPython-3.1.42-py3-none-any.whl (195 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m195.4/195.4 kB\u001b[0m \u001b[31m19.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: requests<3,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb==0.16.3->-r requirements.txt (line 7)) (2.31.0)\n", + "Collecting sentry-sdk>=1.0.0 (from wandb==0.16.3->-r requirements.txt (line 7))\n", + " Downloading sentry_sdk-1.43.0-py2.py3-none-any.whl (264 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m264.6/264.6 kB\u001b[0m \u001b[31m14.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting docker-pycreds>=0.4.0 (from wandb==0.16.3->-r requirements.txt (line 7))\n", + " Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\n", + "Collecting setproctitle (from wandb==0.16.3->-r requirements.txt (line 7))\n", + " Downloading setproctitle-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb==0.16.3->-r requirements.txt (line 7)) (67.7.2)\n", + "Requirement already satisfied: appdirs>=1.4.3 in /usr/local/lib/python3.10/dist-packages (from wandb==0.16.3->-r requirements.txt (line 7)) (1.4.4)\n", + "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb==0.16.3->-r requirements.txt (line 7)) (3.20.3)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface_hub==0.20.3->-r requirements.txt (line 8)) (3.13.1)\n", + "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub==0.20.3->-r requirements.txt (line 8)) (2023.6.0)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub==0.20.3->-r requirements.txt (line 8)) (4.10.0)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers==4.40.0.dev0->-r requirements.txt (line 1)) (2023.12.25)\n", + "Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers==4.40.0.dev0->-r requirements.txt (line 1)) (0.15.2)\n", + "Requirement already satisfied: pyarrow>=12.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets>=2.16.1->-r requirements.txt (line 3)) (14.0.2)\n", + "Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets>=2.16.1->-r requirements.txt (line 3)) (0.6)\n", + "Collecting dill<0.3.9,>=0.3.0 (from datasets>=2.16.1->-r requirements.txt (line 3))\n", + " Downloading dill-0.3.8-py3-none-any.whl (116 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m15.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets>=2.16.1->-r requirements.txt (line 3)) (1.5.3)\n", + "Collecting xxhash (from datasets>=2.16.1->-r requirements.txt (line 3))\n", + " Downloading xxhash-3.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m22.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting multiprocess (from datasets>=2.16.1->-r requirements.txt (line 3))\n", + " Downloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m16.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets>=2.16.1->-r requirements.txt (line 3)) (3.9.3)\n", + "Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from docker-pycreds>=0.4.0->wandb==0.16.3->-r requirements.txt (line 7)) (1.16.0)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.16.1->-r requirements.txt (line 3)) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.16.1->-r requirements.txt (line 3)) (23.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.16.1->-r requirements.txt (line 3)) (1.4.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.16.1->-r requirements.txt (line 3)) (6.0.5)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.16.1->-r requirements.txt (line 3)) (1.9.4)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets>=2.16.1->-r requirements.txt (line 3)) (4.0.3)\n", + "Collecting gitdb<5,>=4.0.1 (from GitPython!=3.1.29,>=1.0.0->wandb==0.16.3->-r requirements.txt (line 7))\n", + " Downloading gitdb-4.0.11-py3-none-any.whl (62 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.7/62.7 kB\u001b[0m \u001b[31m8.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb==0.16.3->-r requirements.txt (line 7)) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb==0.16.3->-r requirements.txt (line 7)) (3.6)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb==0.16.3->-r requirements.txt (line 7)) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb==0.16.3->-r requirements.txt (line 7)) (2024.2.2)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate==0.27.1->-r requirements.txt (line 2)) (1.12)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate==0.27.1->-r requirements.txt (line 2)) (3.2.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate==0.27.1->-r requirements.txt (line 2)) (3.1.3)\n", + "Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=1.10.0->accelerate==0.27.1->-r requirements.txt (line 2))\n", + " Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m23.7/23.7 MB\u001b[0m \u001b[31m28.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=1.10.0->accelerate==0.27.1->-r requirements.txt (line 2))\n", + " Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m823.6/823.6 kB\u001b[0m \u001b[31m50.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=1.10.0->accelerate==0.27.1->-r requirements.txt (line 2))\n", + " Downloading nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m14.1/14.1 MB\u001b[0m \u001b[31m58.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=1.10.0->accelerate==0.27.1->-r requirements.txt (line 2))\n", + " Downloading nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m731.7/731.7 MB\u001b[0m \u001b[31m2.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nvidia-cublas-cu12==12.1.3.1 (from torch>=1.10.0->accelerate==0.27.1->-r requirements.txt (line 2))\n", + " Downloading nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m410.6/410.6 MB\u001b[0m \u001b[31m2.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nvidia-cufft-cu12==11.0.2.54 (from torch>=1.10.0->accelerate==0.27.1->-r requirements.txt (line 2))\n", + " Downloading nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m121.6/121.6 MB\u001b[0m \u001b[31m8.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nvidia-curand-cu12==10.3.2.106 (from torch>=1.10.0->accelerate==0.27.1->-r requirements.txt (line 2))\n", + " Downloading nvidia_curand_cu12-10.3.2.106-py3-none-manylinux1_x86_64.whl (56.5 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.5/56.5 MB\u001b[0m \u001b[31m14.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nvidia-cusolver-cu12==11.4.5.107 (from torch>=1.10.0->accelerate==0.27.1->-r requirements.txt (line 2))\n", + " Downloading nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl (124.2 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m124.2/124.2 MB\u001b[0m \u001b[31m8.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nvidia-cusparse-cu12==12.1.0.106 (from torch>=1.10.0->accelerate==0.27.1->-r requirements.txt (line 2))\n", + " Downloading nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl (196.0 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m196.0/196.0 MB\u001b[0m \u001b[31m6.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nvidia-nccl-cu12==2.19.3 (from torch>=1.10.0->accelerate==0.27.1->-r requirements.txt (line 2))\n", + " Downloading nvidia_nccl_cu12-2.19.3-py3-none-manylinux1_x86_64.whl (166.0 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m166.0/166.0 MB\u001b[0m \u001b[31m3.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting nvidia-nvtx-cu12==12.1.105 (from torch>=1.10.0->accelerate==0.27.1->-r requirements.txt (line 2))\n", + " Downloading nvidia_nvtx_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (99 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m99.1/99.1 kB\u001b[0m \u001b[31m13.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: triton==2.2.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.10.0->accelerate==0.27.1->-r requirements.txt (line 2)) (2.2.0)\n", + "Collecting nvidia-nvjitlink-cu12 (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.10.0->accelerate==0.27.1->-r requirements.txt (line 2))\n", + " Downloading nvidia_nvjitlink_cu12-12.4.99-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m71.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting docstring-parser>=0.14.1 (from tyro>=0.5.11->trl==0.7.10->-r requirements.txt (line 6))\n", + " Downloading docstring_parser-0.16-py3-none-any.whl (36 kB)\n", + "Requirement already satisfied: rich>=11.1.0 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.5.11->trl==0.7.10->-r requirements.txt (line 6)) (13.7.1)\n", + "Collecting shtab>=1.5.6 (from tyro>=0.5.11->trl==0.7.10->-r requirements.txt (line 6))\n", + " Downloading shtab-1.7.1-py3-none-any.whl (14 kB)\n", + "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets>=2.16.1->-r requirements.txt (line 3)) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets>=2.16.1->-r requirements.txt (line 3)) (2023.4)\n", + "Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb==0.16.3->-r requirements.txt (line 7))\n", + " Downloading smmap-5.0.1-py3-none-any.whl (24 kB)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.5.11->trl==0.7.10->-r requirements.txt (line 6)) (3.0.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.5.11->trl==0.7.10->-r requirements.txt (line 6)) (2.16.1)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.10.0->accelerate==0.27.1->-r requirements.txt (line 2)) (2.1.5)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.10.0->accelerate==0.27.1->-r requirements.txt (line 2)) (1.3.0)\n", + "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich>=11.1.0->tyro>=0.5.11->trl==0.7.10->-r requirements.txt (line 6)) (0.1.2)\n", + "Building wheels for collected packages: transformers\n", + " Building wheel for transformers (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for transformers: filename=transformers-4.40.0.dev0-py3-none-any.whl size=8802691 sha256=114bda17deca705bc7d115c456c221930feb03621c99e3cdfef186e3adbcb645\n", + " Stored in directory: /tmp/pip-ephem-wheel-cache-85oztzfi/wheels/e7/9c/5b/e1a9c8007c343041e61cc484433d512ea9274272e3fcbe7c16\n", + "Successfully built transformers\n", + "Installing collected packages: bitsandbytes, xxhash, smmap, shtab, setproctitle, sentry-sdk, nvidia-nvtx-cu12, nvidia-nvjitlink-cu12, nvidia-nccl-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, docstring-parser, docker-pycreds, dill, nvidia-cusparse-cu12, nvidia-cudnn-cu12, multiprocess, gitdb, tyro, nvidia-cusolver-cu12, GitPython, wandb, transformers, datasets, accelerate, trl, peft\n", + " Attempting uninstall: transformers\n", + " Found existing installation: transformers 4.38.2\n", + " Uninstalling transformers-4.38.2:\n", + " Successfully uninstalled transformers-4.38.2\n", + "Successfully installed GitPython-3.1.42 accelerate-0.27.1 bitsandbytes-0.41.3 datasets-2.18.0 dill-0.3.8 docker-pycreds-0.4.0 docstring-parser-0.16 gitdb-4.0.11 multiprocess-0.70.16 nvidia-cublas-cu12-12.1.3.1 nvidia-cuda-cupti-cu12-12.1.105 nvidia-cuda-nvrtc-cu12-12.1.105 nvidia-cuda-runtime-cu12-12.1.105 nvidia-cudnn-cu12-8.9.2.26 nvidia-cufft-cu12-11.0.2.54 nvidia-curand-cu12-10.3.2.106 nvidia-cusolver-cu12-11.4.5.107 nvidia-cusparse-cu12-12.1.0.106 nvidia-nccl-cu12-2.19.3 nvidia-nvjitlink-cu12-12.4.99 nvidia-nvtx-cu12-12.1.105 peft-0.8.2 sentry-sdk-1.43.0 setproctitle-1.3.3 shtab-1.7.1 smmap-5.0.1 transformers-4.40.0.dev0 trl-0.7.10 tyro-0.7.3 wandb-0.16.3 xxhash-3.4.1\n" + ] + } + ], + "source": [ + "!pip install -r requirements.txt" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 556, + "referenced_widgets": [ + "a59be920211d4a629619061b4156604e", + "320abe23abd243788fb5b6e3a4219f06", + "133bc912b837493fb54bcb224939fb56", + "6b6f486624974118a1fcd249f78b7ac5", + "4f749f4873c9419889088a9d79a606cc", + "fcc71180512d45158f2c0db496d72bb6", + "5dd0f7bed15347e5a46a3d1e4974cb6e", + "88620da195a14fc68cd3cf20f7064600", + "c312e8901ac448eca80c78f7850024e9", + "434f4c372b4142f2a140e1a671a79d76", + "7ea5aaa072e1460ca5d179e897d7bd38", + "44dfb705d7c3486ea41dd6b6a721ecd1", + "e25d1f4bede2466ba97a2da67e131efd", + "7e4119bd74a24c59939ae896aa67c2ad", + "012e4ef36f2d4b2bb5d0cc3c3f8702f8", + "850c9c9e75014008865a3d7121866ba1", + "c8e15e15c2484c39931ff5084297022b", + "814b7d9270a94c319c5d3fb6aba85318", + "ff0101c1a0b04eaea3302067f29abbc8", + "5655f94fbbf8458fa963f3affac43413", + "caec5b7b467d4aaa81ce6ca92e341549", + "50bd077fd7b34a7aa165f365063f8ebd", + "4070a4ddb21649a8813f2556d836a27c", + "7a7baf9ca4e44624a86bab3458d74757", + "fb4581ad59f04d6cba500249524f5a00", + "dd1d337fa7c2432e972d58c2e8255886", + "b77f17d2fa704fa7a4cb1cedd75ec12b", + "f68f09ccc9034bb9842e0da39c941f95", + "c809041f7810415f88442e7689f96c73", + "608b31149b214773a7399350b5318c32", + "b4bcb1d633924c41906fb3b0d1de8ee8", + "2913faca6d6141aa8aa15529ab929a41", + "cc70d9dbc1674f8e9bfd441378f44476", + "bfd830a4eded4bd6ac5ff67a989624e6", + "6deb2e2173c346888b6083b18a0280be", + "672b4410650a43d3822e06cafc32f4c5", + "2da52e84dca94973990312dfeed62782", + "20e48fa6dcff4ccfbb0cc46818c730a4", + "4410646c09b84eff9fdc256c40892ab0", + "b43134cd3e574585ae7d58df14456e6f", + "c2852cfe02f84853a1f4bc179160b82c", + "7ca267b28d97439093601434ce722f62", + "cf11f31c9a764d4481b538d1aecb3710", + "922a5c3b91934880b7c6d9a6b0e66252", + "405322bdc22648bc921fe2d5b4c60c53", + "69ee2ed369ea433286d48a0e20d02425", + "003952586a284ecfa007d1995f4186a2", + "5f15255434ab44a88aff781fc02159a3", + "c30b3aa28bce414798eeb24fdd535330", + "aaf3473056d74a4bb987d902ff550fd4", + "d5b717e07025492bbce5427f20c2961b", + "73b45109768b4b529dd97655fda88bc8", + "2b6b54b1c23a4b9884434b8d87fd24f2", + "8ee0bacd54584486b43693c4c43e2f53", + "827109d7e9ce4f8dac31f790de3ba880", + "5ad0db860358493ea4b0a142cebdb933", + "1f973b3f42e14f4abe3269803c592281", + "d93576400e224fe3af5c134add62380d", + "7d796cbf4581445bb06d74cfb3fcc757", + "c6ee7f2f02d74555a00c2a90d25414e1", + "2e2d3646da404372afdbdafb90bf1355", + "1eee2f3d00494253a94539ebf4b7f678", + "7e8525b90faa4766acc656281f41572b", + "fa3c4bebab8547eb8d3faa667d580692", + "42d23b11dc9a4d1f9450943f4e373968", + "2c7465c4ae394811869acefb4a4befe0", + "b14725886e7c48d0ad56016b172fff4c", + "ffa938c3d01541bab8d70cccab998b08", + "1b7234b5168b423eb72b46a7e7cbe525", + "c379c78ae8394a8c9500f0ec38510816", + "88c833d88e5a4c7ca59af3e08406924b", + "3dd021ffac1b45dea1f9b0cff3e02543", + "029565ee66b3446eae898171e1c6e834", + "eaafdfff4bae46668ac5631ffeda2cdc", + "b788169ede0845fea71535554ad10a67", + "04100a11680640df9e68246fbe5d0451", + "7dba004196934bee9cb0c766da40e68e" + ] + }, + "id": "sSwkDbbaHK3E", + "outputId": "f6d31671-8e28-4b63-9739-1338d80e4d25" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:88: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a59be920211d4a629619061b4156604e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "tokenizer_config.json: 0%| | 0.00/7.88k [00:00 {\n function loadScript(url) {\n return new Promise(function(resolve, reject) {\n let newScript = document.createElement(\"script\");\n newScript.onerror = reject;\n newScript.onload = resolve;\n document.body.appendChild(newScript);\n newScript.src = url;\n });\n }\n loadScript(\"https://cdn.jsdelivr.net/npm/postmate/build/postmate.min.js\").then(() => {\n const iframe = document.createElement('iframe')\n iframe.style.cssText = \"width:0;height:0;border:none\"\n document.body.appendChild(iframe)\n const handshake = new Postmate({\n container: iframe,\n url: 'https://wandb.ai/authorize'\n });\n const timeout = setTimeout(() => reject(\"Couldn't auto authenticate\"), 5000)\n handshake.then(function(child) {\n child.on('authorize', data => {\n clearTimeout(timeout)\n resolve(data)\n });\n });\n })\n });\n ", + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: You can find your API key in your browser here: https://wandb.ai/authorize\n", + "wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " ··········\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n" + ] + }, + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import wandb\n", + "wandb.login()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "tFIqEZpCJVNe", + "outputId": "32755bb1-2067-4708-a5a9-4c4a10ce6f64" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Token will not been saved to git credential helper. Pass `add_to_git_credential=True` if you want to set the git credential as well.\n", + "Token is valid (permission: write).\n", + "Your token has been saved to /root/.cache/huggingface/token\n", + "Login successful\n" + ] + } + ], + "source": [ + "from huggingface_hub import notebook_login\n", + "\n", + "notebook_login()" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 234, + "referenced_widgets": [ + "309e953b62954bacb8922437d257a1f1", + "b5f237544dde45c59de10da182608941", + "dc90e05c4f7045d7aea863618ccadad2", + "584d7ec353ff4722a2d77585f4839115", + "9496fb1192374db7943f79c5a232b2b4", + "6245c9499721471881b6a0265c43be8e", + "4941882fb90244d39a12f183aa92da25", + "cc7b43477eb7408eb4de24392a2a907f", + "2096c7fe4d6c4a7eb313a0d95d2e9844", + "d1858dc2ed284af6ac41446b7f07ea6e", + "22092f32014e438fa73fa880a6ec600c", + "80a83bae2f3343aea7477a08a571fb1b", + "a716af012fdf470c9f08baa9efd6454e", + "cddef4db899849adac12bf086f49f01a", + "afc9a06ecc16456db6f624fed7ecfeb6", + "acda9da77e7444b8bd769f1ed61f95b5", + "7a0bdb9dd604436ab77bf3681960f70f", + "eceb2ad161f14c159cda898e883ea239", + "f61726660f4b4b898e05ed9d7b54209b", + "f74eb867afc842dda1642c4c7066a908", + "67aca1d36bf04f2d8ab56a6c597d3eb7", + "29132b57a52746eb9e51f5f6aa6bd05b", + "d464effbf7f34d0fa7b846af3a7f9af0", + "9bcfdcfcfc6e4652bed5ff3b99336040", + "4d24d968f2ad44c0a157b4e2c0476f1d", + "7bf72e9ea3cc4231b62b1a3f201724c1", + "abb1ccb4fa7042939d395bfbb2361f42", + "b2787ecb49cf44fe8b3510981c5cd659", + "d7d4a84bf52540fda77e1d3e6c4d376e", + "9706ef6c7a3147e9bc2805cd5e6ac2ae", + "e58a2c12f9ba4b599e0bff5c3d9ce561", + "6b7b0b1d545a4aef986d29feedda7f4d", + "bf000603106047daa85b3caf2e927d40" + ] + }, + "id": "0b0QYsoCNOGm", + "outputId": "59156f6c-b686-40fe-b267-dada1e94825f" + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "309e953b62954bacb8922437d257a1f1", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Downloading readme: 0%| | 0.00/3.30k [00:00=1.4.0 in /usr/local/lib/python3.10/dist-packages (from trl==0.8.2.dev0) (2.2.1+cu121)\n", + "Requirement already satisfied: transformers>=4.31.0 in /usr/local/lib/python3.10/dist-packages (from trl==0.8.2.dev0) (4.40.0.dev0)\n", + "Requirement already satisfied: numpy>=1.18.2 in /usr/local/lib/python3.10/dist-packages (from trl==0.8.2.dev0) (1.25.2)\n", + "Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (from trl==0.8.2.dev0) (0.27.1)\n", + "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (from trl==0.8.2.dev0) (2.18.0)\n", + "Requirement already satisfied: tyro>=0.5.11 in /usr/local/lib/python3.10/dist-packages (from trl==0.8.2.dev0) (0.7.3)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (3.13.1)\n", + "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (4.10.0)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (1.12)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (3.2.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (3.1.3)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (2023.6.0)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (12.1.105)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (8.9.2.26)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (11.0.2.54)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (11.4.5.107)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (12.1.0.106)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (2.19.3)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (12.1.105)\n", + "Requirement already satisfied: triton==2.2.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.4.0->trl==0.8.2.dev0) (2.2.0)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.4.0->trl==0.8.2.dev0) (12.4.99)\n", + "Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->trl==0.8.2.dev0) (0.20.3)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->trl==0.8.2.dev0) (24.0)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->trl==0.8.2.dev0) (6.0.1)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->trl==0.8.2.dev0) (2023.12.25)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->trl==0.8.2.dev0) (2.31.0)\n", + "Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->trl==0.8.2.dev0) (0.15.2)\n", + "Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->trl==0.8.2.dev0) (0.4.2)\n", + "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers>=4.31.0->trl==0.8.2.dev0) (4.66.2)\n", + "Requirement already satisfied: docstring-parser>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.5.11->trl==0.8.2.dev0) (0.16)\n", + "Requirement already satisfied: rich>=11.1.0 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.5.11->trl==0.8.2.dev0) (13.7.1)\n", + "Requirement already satisfied: shtab>=1.5.6 in /usr/local/lib/python3.10/dist-packages (from tyro>=0.5.11->trl==0.8.2.dev0) (1.7.1)\n", + "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate->trl==0.8.2.dev0) (5.9.5)\n", + "Requirement already satisfied: pyarrow>=12.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets->trl==0.8.2.dev0) (14.0.2)\n", + "Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets->trl==0.8.2.dev0) (0.6)\n", + "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets->trl==0.8.2.dev0) (0.3.8)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets->trl==0.8.2.dev0) (1.5.3)\n", + "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets->trl==0.8.2.dev0) (3.4.1)\n", + "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets->trl==0.8.2.dev0) (0.70.16)\n", + "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets->trl==0.8.2.dev0) (3.9.3)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->trl==0.8.2.dev0) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->trl==0.8.2.dev0) (23.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->trl==0.8.2.dev0) (1.4.1)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->trl==0.8.2.dev0) (6.0.5)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->trl==0.8.2.dev0) (1.9.4)\n", + "Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets->trl==0.8.2.dev0) (4.0.3)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers>=4.31.0->trl==0.8.2.dev0) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers>=4.31.0->trl==0.8.2.dev0) (3.6)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers>=4.31.0->trl==0.8.2.dev0) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers>=4.31.0->trl==0.8.2.dev0) (2024.2.2)\n", + "Requirement already satisfied: markdown-it-py>=2.2.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.5.11->trl==0.8.2.dev0) (3.0.0)\n", + "Requirement already satisfied: pygments<3.0.0,>=2.13.0 in /usr/local/lib/python3.10/dist-packages (from rich>=11.1.0->tyro>=0.5.11->trl==0.8.2.dev0) (2.16.1)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.4.0->trl==0.8.2.dev0) (2.1.5)\n", + "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets->trl==0.8.2.dev0) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets->trl==0.8.2.dev0) (2023.4)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.4.0->trl==0.8.2.dev0) (1.3.0)\n", + "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py>=2.2.0->rich>=11.1.0->tyro>=0.5.11->trl==0.8.2.dev0) (0.1.2)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets->trl==0.8.2.dev0) (1.16.0)\n", + "Building wheels for collected packages: trl\n", + " Building wheel for trl (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for trl: filename=trl-0.8.2.dev0-py3-none-any.whl size=237720 sha256=b18a84474ea863787fb9781dea9dd28a2a8b13faaa0d4075b0edc8de500d0ac1\n", + " Stored in directory: /tmp/pip-ephem-wheel-cache-sv1ib3mn/wheels/22/0e/42/319b77b2648bb6140ef2b08b0478ede9ca3cc7879fcd022d36\n", + "Successfully built trl\n", + "Installing collected packages: trl\n", + " Attempting uninstall: trl\n", + " Found existing installation: trl 0.7.10\n", + " Uninstalling trl-0.7.10:\n", + " Successfully uninstalled trl-0.7.10\n", + "Successfully installed trl-0.8.2.dev0\n", + "Requirement already satisfied: peft in /usr/local/lib/python3.10/dist-packages (0.8.2)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from peft) (1.25.2)\n", + "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from peft) (24.0)\n", + "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from peft) (5.9.5)\n", + "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from peft) (6.0.1)\n", + "Requirement already satisfied: torch>=1.13.0 in /usr/local/lib/python3.10/dist-packages (from peft) (2.2.1+cu121)\n", + "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (from peft) (4.40.0.dev0)\n", + "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from peft) (4.66.2)\n", + "Requirement already satisfied: accelerate>=0.21.0 in /usr/local/lib/python3.10/dist-packages (from peft) (0.27.1)\n", + "Requirement already satisfied: safetensors in /usr/local/lib/python3.10/dist-packages (from peft) (0.4.2)\n", + "Requirement already satisfied: huggingface-hub>=0.17.0 in /usr/local/lib/python3.10/dist-packages (from peft) (0.20.3)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.17.0->peft) (3.13.1)\n", + "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.17.0->peft) (2023.6.0)\n", + "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.17.0->peft) (2.31.0)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.17.0->peft) (4.10.0)\n", + "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (1.12)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.2.1)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (3.1.3)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (12.1.105)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (8.9.2.26)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (11.0.2.54)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (11.4.5.107)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (12.1.0.106)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.19.3 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (2.19.3)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (12.1.105)\n", + "Requirement already satisfied: triton==2.2.0 in /usr/local/lib/python3.10/dist-packages (from torch>=1.13.0->peft) (2.2.0)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in /usr/local/lib/python3.10/dist-packages (from nvidia-cusolver-cu12==11.4.5.107->torch>=1.13.0->peft) (12.4.99)\n", + "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (2023.12.25)\n", + "Requirement already satisfied: tokenizers<0.19,>=0.14 in /usr/local/lib/python3.10/dist-packages (from transformers->peft) (0.15.2)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch>=1.13.0->peft) (2.1.5)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.17.0->peft) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.17.0->peft) (3.6)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.17.0->peft) (2.0.7)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub>=0.17.0->peft) (2024.2.2)\n", + "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch>=1.13.0->peft) (1.3.0)\n" + ] + } + ], + "source": [ + "!pip install git+https://github.com/huggingface/trl.git\n", + "!pip install peft" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "id": "8vN-eLgVQryO" + }, + "outputs": [], + "source": [ + "import argparse\n", + "import multiprocessing\n", + "import os\n", + "\n", + "import torch\n", + "import transformers\n", + "from accelerate import PartialState\n", + "from datasets import load_dataset\n", + "from peft import LoraConfig\n", + "from transformers import (\n", + " AutoModelForCausalLM,\n", + " BitsAndBytesConfig,\n", + " logging,\n", + " set_seed,\n", + ")\n", + "from trl import SFTTrainer" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "id": "WGUK5keeTHoO" + }, + "outputs": [], + "source": [ + "parser = argparse.ArgumentParser()\n", + "parser.add_argument(\"--model_id\", type=str, default=\"bigcode/starcoder2-3b\")\n", + "parser.add_argument(\"--dataset_name\", type=str, default=\"bigcode/the-stack-smol\")\n", + "parser.add_argument(\"--subset\", type=str, default=\"data/php\")\n", + "parser.add_argument(\"--split\", type=str, default=\"train\")\n", + "parser.add_argument(\"--dataset_text_field\", type=str, default=\"content\")\n", + "\n", + "parser.add_argument(\"--max_seq_length\", type=int, default=512)\n", + "parser.add_argument(\"--max_steps\", type=int, default=1000)\n", + "parser.add_argument(\"--micro_batch_size\", type=int, default=1)\n", + "parser.add_argument(\"--gradient_accumulation_steps\", type=int, default=4)\n", + "parser.add_argument(\"--weight_decay\", type=float, default=0.01)\n", + "parser.add_argument(\"--fp16\", type=bool, default=True)\n", + "\n", + "parser.add_argument(\"--attention_dropout\", type=float, default=0.1)\n", + "parser.add_argument(\"--learning_rate\", type=float, default=2e-4)\n", + "parser.add_argument(\"--lr_scheduler_type\", type=str, default=\"cosine\")\n", + "parser.add_argument(\"--warmup_steps\", type=int, default=100)\n", + "parser.add_argument(\"--seed\", type=int, default=0)\n", + "parser.add_argument(\"--output_dir\", type=str, default=\"finetunedPHP_starcoder2\")\n", + "parser.add_argument(\"--num_proc\", type=int, default=2)#T4 gpu of google colab\n", + "parser.add_argument(\"--push_to_hub\", type=bool, default=True)\n", + "\n", + "args = parser.parse_args(args=[])" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "sRJpl3NoTUqk", + "outputId": "52c0e1ed-8a22-469a-c730-9fcd212960aa" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "Namespace(model_id='bigcode/starcoder2-3b', dataset_name='bigcode/the-stack-smol', subset='data/php', split='train', dataset_text_field='content', max_seq_length=512, max_steps=1000, micro_batch_size=1, gradient_accumulation_steps=4, weight_decay=0.01, fp16=True, attention_dropout=0.1, learning_rate=0.0002, lr_scheduler_type='cosine', warmup_steps=100, seed=0, output_dir='finetunedPHP_starcoder2', num_proc=2, push_to_hub=True)" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "args" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "id": "WEJKhbrOTWwW" + }, + "outputs": [], + "source": [ + "def print_trainable_parameters(model):\n", + " \"\"\"\n", + " Prints the number of trainable parameters in the model.\n", + " \"\"\"\n", + " trainable_params = 0\n", + " all_param = 0\n", + " for _, param in model.named_parameters():\n", + " all_param += param.numel()\n", + " if param.requires_grad:\n", + " trainable_params += param.numel()\n", + " print(\n", + " f\"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "id": "Dd0D-RwtTaxG" + }, + "outputs": [], + "source": [ + "def main(args):\n", + " # config\n", + " bnb_config = BitsAndBytesConfig(\n", + " load_in_4bit=True,\n", + " bnb_4bit_quant_type=\"nf4\",\n", + " bnb_4bit_compute_dtype=torch.bfloat16,\n", + " )\n", + " lora_config = LoraConfig(\n", + " r=8,\n", + " target_modules=[\n", + " \"q_proj\",\n", + " \"o_proj\",\n", + " \"k_proj\",\n", + " \"v_proj\",\n", + " \"gate_proj\",\n", + " \"up_proj\",\n", + " \"down_proj\",\n", + " ],\n", + " task_type=\"CAUSAL_LM\",\n", + " )\n", + "\n", + " # load model and dataset\n", + " token = os.environ.get(\"HF_TOKEN\", None)\n", + " model = AutoModelForCausalLM.from_pretrained(\n", + " args.model_id,\n", + " quantization_config=bnb_config,\n", + " device_map={\"\": PartialState().process_index},\n", + " attention_dropout=args.attention_dropout,\n", + " )\n", + " print_trainable_parameters(model)\n", + "\n", + " data = load_dataset(\n", + " args.dataset_name,\n", + " data_dir=args.subset,\n", + " split=args.split,\n", + " token=token,\n", + " num_proc=args.num_proc if args.num_proc else multiprocessing.cpu_count(),\n", + " )\n", + "\n", + " # setup the trainer\n", + " trainer = SFTTrainer(\n", + " model=model,\n", + " train_dataset=data,\n", + " max_seq_length=args.max_seq_length,\n", + " args=transformers.TrainingArguments(\n", + " per_device_train_batch_size=args.micro_batch_size,\n", + " gradient_accumulation_steps=args.gradient_accumulation_steps,\n", + " warmup_steps=args.warmup_steps,\n", + " max_steps=args.max_steps,\n", + " learning_rate=args.learning_rate,\n", + " lr_scheduler_type=args.lr_scheduler_type,\n", + " weight_decay=args.weight_decay,\n", + " fp16=args.fp16,\n", + " logging_strategy=\"steps\",\n", + " logging_steps=10,\n", + " output_dir=args.output_dir,\n", + " optim=\"paged_adamw_8bit\",\n", + " seed=args.seed,\n", + " run_name=f\"train-{args.model_id.split('/')[-1]}\",\n", + " report_to=\"wandb\",\n", + " ),\n", + " peft_config=lora_config,\n", + " dataset_text_field=args.dataset_text_field,\n", + " )\n", + "\n", + " # launch\n", + " print(\"Training...\")\n", + " trainer.train()\n", + "\n", + " print(\"Saving the last checkpoint of the model\")\n", + " model.save_pretrained(os.path.join(args.output_dir, \"final_checkpoint/\"))\n", + " if args.push_to_hub:\n", + " trainer.push_to_hub(\"Upload model\")\n", + " print(\"Training Done! 💥\")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000, + "referenced_widgets": [ + "72eb5b869a3d4208be0794d88cf1273a", + "87327145b663471cab5703d1096091c3", + "8170426464f047b2b52d35ce3f2e7fe4", + "2150e977ba5148d4be5c36c296ce58e7", + "7e083492d3624852b09180109dc1f61c", + "19d1edff57b347f7a8899bb4eb67bd9c", + "69e488dd84264520a1fc58d78f5f4b0e", + "b33305784bca4e7db6ba0baca52a657d", + "4db741b884a746aa9b4bfba82f9604a9", + "178e36c3ecac4b06acd2b85553722186", + "136e00525e534b4babd69f26354cbfc6", + "70174cbd47fd4651b21b2e37a3d67e74", + "72f5fbee99084a059855e681a3d3c4e7", + "b601b8b86ee440009d42db30e6f9a81a", + "5411fe75bd6743e69dd5f36399ca22a5", + "9660db8962184e6d90dfb6e63c4a67a1", + "bf2b0ed1fe0245b096941541cdd37d04", + "e9c1271507a54232b7f6412061cd0277", + "cfcc35cb26dc4316bf7e55932bc83d9e", + "9795872f83e04e0aa5ee8b894456b690", + "b9b623710b42439196150a2fc0faec60", + "038e5fc62ce14d4cbba1b081c176384a", + "c073c5f0bfdf4fd19348b7acb2a7aca1", + "7df2b353fe7d43dc87b050b7c7b4db58", + "00857252102741ada96bb9b30c2d5e51", + "55a5a5c8e20a4e64b5b3111f2a1b46aa", + "d0e9a0d4f56d4b16bce30e732233caae", + "f3280a91b6d94b0698effd4651090848", + "462bc9ef802149f8a933df97825ba948", + "46b733575dab43e789fced023aa02039", + "d46ae17e30864685aabeeee2f5c87510", + "b762205684ca4f46bcb3c35ac3b0169a", + "b1cedb47cc1d498ab1df16bb7be92c3d", + "ee0008017f364de4a846634035867bf6", + "fcca509c05d1406fa7958b7684df5a5a", + "7cb7112c3a724898ab5622d89c9f8f93", + "8abb8cf8180c42ba829f86736e4bb75f", + "efcc10a43a304619aa47b4ccef1ec040", + "05ee7576463e4458aa5de62c84ff2bcc", + "8af91fd8a896407abbe4b64c6e03213d", + "7d563aac5fad451abc083a3a24355b80", + "58d3d858345b49d2b286809d2a28ad56", + "a9b9c3558f7b411484467fcd61829ff0", + "e543f37c864c45ca8037a6eaa64d384a", + "82bd138f622e4d1bb72caaaf04023826", + "5d1245a75ab4428982ce0492c243ed5d", + "78c6ccd604ad484aa05679cce35cbcb2", + "9844490d9ff447019326ae4ba7c8eca1", + "281a882ba6834f11810950e12c26ceb3", + "1e6d12d1f8704a558842a6ec2cee3473", + "4fe72209ab7b40c5b2fde5177b4e8091", + "4b90abf81ee8469095ad761bd4533281", + "c260982c5dec4fab8197da28d59e08de", + "9d91d176e0eb47f388279333b1d62b12", + "b071c849108344c187207f9d0451438c", + "1a101882a5a24291b776c78f55365ea1", + "4868fa4cc27f455e9cbfc85d149b346f", + "b0c693b8aac94dc48a7753ab275ac11c", + "f341ea3624db4035b761e6fa402146ce", + "9b3faf1529354f8e8d3ab13d31fad9a9", + "4c0f33db15124a80ae0a5e901b6f487b", + "01a767b253f2415d801c1ec83a2098de", + "7b007e8d8ffb4de6b3a5f3250502e466" + ] + }, + "id": "GjuYlUNdTqnq", + "outputId": "8e9ed1ef-55d4-48cb-b59f-22e3c0de3fe8" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "trainable params: 151369728 || all params: 1591200768 || trainable%: 9.5129245186488\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "72eb5b869a3d4208be0794d88cf1273a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Map: 0%| | 0/10000 [00:00" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.16.3" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in /content/starcoder2/wandb/run-20240325_044236-yaifaow3" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run train-starcoder2-3b to Weights & Biases (docs)
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/starcoder2/huggingface" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/starcoder2/huggingface/runs/yaifaow3" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'loss': 2.1512, 'grad_norm': 2.374952793121338, 'learning_rate': 2e-05, 'epoch': 0.0}\n", + "{'loss': 2.8922, 'grad_norm': 1.7982090711593628, 'learning_rate': 3.8e-05, 'epoch': 0.01}\n", + "{'loss': 2.9155, 'grad_norm': 2.2329304218292236, 'learning_rate': 5.4000000000000005e-05, 'epoch': 0.01}\n", + "{'loss': 2.7554, 'grad_norm': 1.7585422992706299, 'learning_rate': 7.4e-05, 'epoch': 0.02}\n", + "{'loss': 1.9353, 'grad_norm': 0.5500292181968689, 'learning_rate': 9.200000000000001e-05, 'epoch': 0.02}\n", + "{'loss': 1.9264, 'grad_norm': 0.4516673684120178, 'learning_rate': 0.00011200000000000001, 'epoch': 0.02}\n", + "{'loss': 1.5593, 'grad_norm': 1.8559635877609253, 'learning_rate': 0.000132, 'epoch': 0.03}\n", + "{'loss': 1.4748, 'grad_norm': 0.9799067974090576, 'learning_rate': 0.000152, 'epoch': 0.03}\n", + "{'loss': 1.5837, 'grad_norm': 2.633228063583374, 'learning_rate': 0.000172, 'epoch': 0.04}\n", + "{'loss': 1.197, 'grad_norm': 2.0126709938049316, 'learning_rate': 0.000192, 'epoch': 0.04}\n", + "{'loss': 1.293, 'grad_norm': 5.246824741363525, 'learning_rate': 0.00019997806834748456, 'epoch': 0.04}\n", + "{'loss': 1.0963, 'grad_norm': 1.4045424461364746, 'learning_rate': 0.00019984407641819812, 'epoch': 0.05}\n", + "{'loss': 1.0916, 'grad_norm': 1.9746116399765015, 'learning_rate': 0.00019958843986159704, 'epoch': 0.05}\n", + "{'loss': 1.0437, 'grad_norm': 0.8059239387512207, 'learning_rate': 0.0001992114701314478, 'epoch': 0.06}\n", + "{'loss': 1.3049, 'grad_norm': 1.7704951763153076, 'learning_rate': 0.0001987136265072988, 'epoch': 0.06}\n", + "{'loss': 1.1942, 'grad_norm': 5.766129970550537, 'learning_rate': 0.00019809551553491916, 'epoch': 0.06}\n", + "{'loss': 1.169, 'grad_norm': 0.8774454593658447, 'learning_rate': 0.00019735789028731604, 'epoch': 0.07}\n", + "{'loss': 1.1191, 'grad_norm': 1.010455846786499, 'learning_rate': 0.00019650164944723115, 'epoch': 0.07}\n", + "{'loss': 1.0834, 'grad_norm': 1.1836800575256348, 'learning_rate': 0.00019552783621223436, 'epoch': 0.08}\n", + "{'loss': 1.1387, 'grad_norm': 1.8022960424423218, 'learning_rate': 0.00019443763702374812, 'epoch': 0.08}\n", + "{'loss': 1.0438, 'grad_norm': 0.9112914800643921, 'learning_rate': 0.00019323238012155123, 'epoch': 0.08}\n", + "{'loss': 1.048, 'grad_norm': 1.0949299335479736, 'learning_rate': 0.00019191353392552344, 'epoch': 0.09}\n", + "{'loss': 1.1616, 'grad_norm': 1.9528100490570068, 'learning_rate': 0.00019048270524660196, 'epoch': 0.09}\n", + "{'loss': 1.0356, 'grad_norm': 1.1788562536239624, 'learning_rate': 0.00018894163732912977, 'epoch': 0.1}\n", + "{'loss': 1.064, 'grad_norm': 1.5582410097122192, 'learning_rate': 0.00018729220772698097, 'epoch': 0.1}\n", + "{'loss': 0.9771, 'grad_norm': 0.9576661586761475, 'learning_rate': 0.00018553642601605068, 'epoch': 0.1}\n", + "{'loss': 0.9917, 'grad_norm': 0.6323869228363037, 'learning_rate': 0.00018367643134589617, 'epoch': 0.11}\n", + "{'loss': 0.9698, 'grad_norm': 0.7207333445549011, 'learning_rate': 0.00018171448983351284, 'epoch': 0.11}\n", + "{'loss': 1.0115, 'grad_norm': 1.595521330833435, 'learning_rate': 0.00017965299180241963, 'epoch': 0.12}\n", + "{'loss': 1.1046, 'grad_norm': 0.9994978308677673, 'learning_rate': 0.00017749444887041799, 'epoch': 0.12}\n", + "{'loss': 1.1536, 'grad_norm': 0.5948280096054077, 'learning_rate': 0.00017524149088957245, 'epoch': 0.12}\n", + "{'loss': 1.0937, 'grad_norm': 0.8551707863807678, 'learning_rate': 0.00017289686274214118, 'epoch': 0.13}\n", + "{'loss': 1.0855, 'grad_norm': 0.748194694519043, 'learning_rate': 0.00017046342099635948, 'epoch': 0.13}\n", + "{'loss': 1.1072, 'grad_norm': 0.8925004005432129, 'learning_rate': 0.00016794413042615168, 'epoch': 0.14}\n", + "{'loss': 1.0105, 'grad_norm': 1.1091358661651611, 'learning_rate': 0.00016534206039901057, 'epoch': 0.14}\n", + "{'loss': 1.0339, 'grad_norm': 0.569215714931488, 'learning_rate': 0.00016266038113644607, 'epoch': 0.14}\n", + "{'loss': 1.1331, 'grad_norm': 0.9173529744148254, 'learning_rate': 0.0001599023598515586, 'epoch': 0.15}\n", + "{'loss': 0.9847, 'grad_norm': 0.8884732127189636, 'learning_rate': 0.0001570713567684432, 'epoch': 0.15}\n", + "{'loss': 1.0464, 'grad_norm': 2.1106643676757812, 'learning_rate': 0.000154170821028274, 'epoch': 0.16}\n", + "{'loss': 1.1202, 'grad_norm': 1.6022874116897583, 'learning_rate': 0.00015120428648705717, 'epoch': 0.16}\n", + "{'loss': 0.9023, 'grad_norm': 0.9261879324913025, 'learning_rate': 0.00014817536741017152, 'epoch': 0.16}\n", + "{'loss': 0.9839, 'grad_norm': 0.6320977807044983, 'learning_rate': 0.00014508775406894307, 'epoch': 0.17}\n", + "{'loss': 0.971, 'grad_norm': 0.6100371479988098, 'learning_rate': 0.00014194520824461771, 'epoch': 0.17}\n", + "{'loss': 0.9614, 'grad_norm': 1.489484429359436, 'learning_rate': 0.0001387515586452103, 'epoch': 0.18}\n", + "{'loss': 0.8534, 'grad_norm': 0.6577572822570801, 'learning_rate': 0.0001355106962408137, 'epoch': 0.18}\n", + "{'loss': 1.0541, 'grad_norm': 1.5166479349136353, 'learning_rate': 0.00013222656952305113, 'epoch': 0.18}\n", + "{'loss': 0.9385, 'grad_norm': 0.9302632212638855, 'learning_rate': 0.00012890317969444716, 'epoch': 0.19}\n", + "{'loss': 1.0154, 'grad_norm': 0.7496092915534973, 'learning_rate': 0.00012554457579357905, 'epoch': 0.19}\n", + "{'loss': 1.4388, 'grad_norm': 11.857297897338867, 'learning_rate': 0.00012215484976194676, 'epoch': 0.2}\n", + "{'loss': 0.9593, 'grad_norm': 1.2743455171585083, 'learning_rate': 0.00011873813145857249, 'epoch': 0.2}\n", + "{'loss': 1.1586, 'grad_norm': 0.5912078619003296, 'learning_rate': 0.00011529858362840382, 'epoch': 0.2}\n", + "{'loss': 0.935, 'grad_norm': 0.8064272403717041, 'learning_rate': 0.00011184039683065013, 'epoch': 0.21}\n", + "{'loss': 0.8872, 'grad_norm': 0.6828736066818237, 'learning_rate': 0.00010836778433323158, 'epoch': 0.21}\n", + "{'loss': 1.0337, 'grad_norm': 0.7725141644477844, 'learning_rate': 0.00010488497697956135, 'epoch': 0.22}\n", + "{'loss': 1.028, 'grad_norm': 1.1346145868301392, 'learning_rate': 0.00010139621803391455, 'epoch': 0.22}\n", + "{'loss': 0.9217, 'grad_norm': 0.5506075024604797, 'learning_rate': 9.790575801166432e-05, 'epoch': 0.22}\n", + "{'loss': 1.0113, 'grad_norm': 1.3045012950897217, 'learning_rate': 9.441784950068362e-05, 'epoch': 0.23}\n", + "{'loss': 0.9007, 'grad_norm': 0.688460111618042, 'learning_rate': 9.093674198022201e-05, 'epoch': 0.23}\n", + "{'loss': 0.8763, 'grad_norm': 1.6210631132125854, 'learning_rate': 8.746667664356956e-05, 'epoch': 0.24}\n", + "{'loss': 0.8863, 'grad_norm': 0.9501083493232727, 'learning_rate': 8.401188123081653e-05, 'epoch': 0.24}\n", + "{'loss': 1.0207, 'grad_norm': 2.010164260864258, 'learning_rate': 8.057656487800282e-05, 'epoch': 0.24}\n", + "{'loss': 1.0258, 'grad_norm': 1.1473491191864014, 'learning_rate': 7.716491298893442e-05, 'epoch': 0.25}\n", + "{'loss': 1.1317, 'grad_norm': 0.6085152626037598, 'learning_rate': 7.378108213591355e-05, 'epoch': 0.25}\n", + "{'loss': 0.9657, 'grad_norm': 0.692514955997467, 'learning_rate': 7.042919499559537e-05, 'epoch': 0.26}\n", + "{'loss': 0.992, 'grad_norm': 0.6227133870124817, 'learning_rate': 6.711333532614168e-05, 'epoch': 0.26}\n", + "{'loss': 0.9975, 'grad_norm': 1.1793878078460693, 'learning_rate': 6.383754299179079e-05, 'epoch': 0.26}\n", + "{'loss': 1.0198, 'grad_norm': 0.6962803602218628, 'learning_rate': 6.0605809040904894e-05, 'epoch': 0.27}\n", + "{'loss': 0.9052, 'grad_norm': 0.8879271149635315, 'learning_rate': 5.7422070843492734e-05, 'epoch': 0.27}\n", + "{'loss': 1.0408, 'grad_norm': 0.7350073456764221, 'learning_rate': 5.4290207294130615e-05, 'epoch': 0.28}\n", + "{'loss': 0.8618, 'grad_norm': 1.4684317111968994, 'learning_rate': 5.121403408612672e-05, 'epoch': 0.28}\n", + "{'loss': 1.2006, 'grad_norm': 0.8536351323127747, 'learning_rate': 4.8197299062686995e-05, 'epoch': 0.28}\n", + "{'loss': 0.8628, 'grad_norm': 1.4125428199768066, 'learning_rate': 4.524367765074499e-05, 'epoch': 0.29}\n", + "{'loss': 0.9095, 'grad_norm': 0.9320453405380249, 'learning_rate': 4.235676838302068e-05, 'epoch': 0.29}\n", + "{'loss': 0.9844, 'grad_norm': 0.8588120937347412, 'learning_rate': 3.954008851376252e-05, 'epoch': 0.3}\n", + "{'loss': 0.9199, 'grad_norm': 0.6745263338088989, 'learning_rate': 3.679706973351491e-05, 'epoch': 0.3}\n", + "{'loss': 0.8081, 'grad_norm': 0.43302035331726074, 'learning_rate': 3.413105398813195e-05, 'epoch': 0.3}\n", + "{'loss': 1.0037, 'grad_norm': 0.7672327160835266, 'learning_rate': 3.154528940713113e-05, 'epoch': 0.31}\n", + "{'loss': 0.898, 'grad_norm': 0.8709444999694824, 'learning_rate': 2.904292634634793e-05, 'epoch': 0.31}\n", + "{'loss': 0.8914, 'grad_norm': 1.3814789056777954, 'learning_rate': 2.6627013549712355e-05, 'epoch': 0.32}\n", + "{'loss': 0.9274, 'grad_norm': 0.49220532178878784, 'learning_rate': 2.4300494434824373e-05, 'epoch': 0.32}\n", + "{'loss': 1.079, 'grad_norm': 1.0098705291748047, 'learning_rate': 2.2066203506852566e-05, 'epoch': 0.32}\n", + "{'loss': 1.0611, 'grad_norm': 0.7284759283065796, 'learning_rate': 1.9926862905126665e-05, 'epoch': 0.33}\n", + "{'loss': 0.9219, 'grad_norm': 3.0494790077209473, 'learning_rate': 1.78850790866296e-05, 'epoch': 0.33}\n", + "{'loss': 0.9271, 'grad_norm': 0.6862832307815552, 'learning_rate': 1.5943339650431576e-05, 'epoch': 0.34}\n", + "{'loss': 0.9586, 'grad_norm': 0.7988976836204529, 'learning_rate': 1.4104010306933557e-05, 'epoch': 0.34}\n", + "{'loss': 0.8254, 'grad_norm': 0.8809404969215393, 'learning_rate': 1.2369331995613665e-05, 'epoch': 0.34}\n", + "{'loss': 0.9002, 'grad_norm': 0.8660016655921936, 'learning_rate': 1.0741418154787442e-05, 'epoch': 0.35}\n", + "{'loss': 1.1073, 'grad_norm': 1.7036733627319336, 'learning_rate': 9.222252146709142e-06, 'epoch': 0.35}\n", + "{'loss': 0.9924, 'grad_norm': 0.7075016498565674, 'learning_rate': 7.81368484114996e-06, 'epoch': 0.36}\n", + "{'loss': 0.9048, 'grad_norm': 1.2110106945037842, 'learning_rate': 6.517432360398556e-06, 'epoch': 0.36}\n", + "{'loss': 0.8631, 'grad_norm': 0.7055613398551941, 'learning_rate': 5.335073988430372e-06, 'epoch': 0.36}\n", + "{'loss': 1.0032, 'grad_norm': 0.7021167278289795, 'learning_rate': 4.268050246793276e-06, 'epoch': 0.37}\n", + "{'loss': 1.0182, 'grad_norm': 2.289949655532837, 'learning_rate': 3.3176611395540626e-06, 'epoch': 0.37}\n", + "{'loss': 0.8698, 'grad_norm': 1.410033106803894, 'learning_rate': 2.4850645694436736e-06, 'epoch': 0.38}\n", + "{'loss': 0.9534, 'grad_norm': 0.9889239072799683, 'learning_rate': 1.771274927131139e-06, 'epoch': 0.38}\n", + "{'loss': 0.9366, 'grad_norm': 0.5054696798324585, 'learning_rate': 1.1771618553447216e-06, 'epoch': 0.38}\n", + "{'loss': 0.8781, 'grad_norm': 1.150292992591858, 'learning_rate': 7.034491893463058e-07, 'epoch': 0.39}\n", + "{'loss': 0.9668, 'grad_norm': 0.7822797894477844, 'learning_rate': 3.50714075049563e-07, 'epoch': 0.39}\n", + "{'loss': 0.9574, 'grad_norm': 1.036647081375122, 'learning_rate': 1.193862658566025e-07, 'epoch': 0.4}\n", + "{'loss': 0.9566, 'grad_norm': 1.1002436876296997, 'learning_rate': 9.747599069576119e-09, 'epoch': 0.4}\n", + "{'train_runtime': 2551.3822, 'train_samples_per_second': 1.568, 'train_steps_per_second': 0.392, 'train_loss': 1.1126603059768676, 'epoch': 0.4}\n", + "Saving the last checkpoint of the model\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "9795872f83e04e0aa5ee8b894456b690", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "adapter_model.safetensors: 0%| | 0.00/18.2M [00:00