diff --git "a/Notebooks/ysda-ml-02-05-finetune.ipynb" "b/Notebooks/ysda-ml-02-05-finetune.ipynb" new file mode 100644--- /dev/null +++ "b/Notebooks/ysda-ml-02-05-finetune.ipynb" @@ -0,0 +1,7487 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "d7e9b191", + "metadata": { + "execution": { + "iopub.execute_input": "2023-04-16T20:10:01.892494Z", + "iopub.status.busy": "2023-04-16T20:10:01.892018Z", + "iopub.status.idle": "2023-04-16T20:10:15.507545Z", + "shell.execute_reply": "2023-04-16T20:10:15.506148Z" + }, + "papermill": { + "duration": 13.624344, + "end_time": "2023-04-16T20:10:15.510606", + "exception": false, + "start_time": "2023-04-16T20:10:01.886262", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting sentence-transformers\r\n", + " Downloading sentence-transformers-2.2.2.tar.gz (85 kB)\r\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m86.0/86.0 kB\u001b[0m \u001b[31m3.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\r\n", + "\u001b[?25h Preparing metadata (setup.py) ... \u001b[?25l-\b \bdone\r\n", + "\u001b[?25hRequirement already satisfied: transformers<5.0.0,>=4.6.0 in /opt/conda/lib/python3.7/site-packages (from sentence-transformers) (4.27.4)\r\n", + "Requirement already satisfied: tqdm in /opt/conda/lib/python3.7/site-packages (from sentence-transformers) (4.64.1)\r\n", + "Requirement already satisfied: torch>=1.6.0 in /opt/conda/lib/python3.7/site-packages (from sentence-transformers) (1.13.0)\r\n", + "Requirement already satisfied: torchvision in /opt/conda/lib/python3.7/site-packages (from sentence-transformers) (0.14.0)\r\n", + "Requirement already satisfied: numpy in /opt/conda/lib/python3.7/site-packages (from sentence-transformers) (1.21.6)\r\n", + "Requirement already satisfied: scikit-learn in /opt/conda/lib/python3.7/site-packages (from sentence-transformers) (1.0.2)\r\n", + "Requirement already satisfied: scipy in /opt/conda/lib/python3.7/site-packages (from sentence-transformers) (1.7.3)\r\n", + "Requirement already satisfied: nltk in /opt/conda/lib/python3.7/site-packages (from sentence-transformers) (3.2.4)\r\n", + "Requirement already satisfied: sentencepiece in /opt/conda/lib/python3.7/site-packages (from sentence-transformers) (0.1.97)\r\n", + "Requirement already satisfied: huggingface-hub>=0.4.0 in /opt/conda/lib/python3.7/site-packages (from sentence-transformers) (0.13.3)\r\n", + "Requirement already satisfied: filelock in /opt/conda/lib/python3.7/site-packages (from huggingface-hub>=0.4.0->sentence-transformers) (3.9.0)\r\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /opt/conda/lib/python3.7/site-packages (from huggingface-hub>=0.4.0->sentence-transformers) (4.4.0)\r\n", + "Requirement already satisfied: importlib-metadata in /opt/conda/lib/python3.7/site-packages (from huggingface-hub>=0.4.0->sentence-transformers) (4.11.4)\r\n", + "Requirement already satisfied: packaging>=20.9 in /opt/conda/lib/python3.7/site-packages (from huggingface-hub>=0.4.0->sentence-transformers) (23.0)\r\n", + "Requirement already satisfied: requests in /opt/conda/lib/python3.7/site-packages (from huggingface-hub>=0.4.0->sentence-transformers) (2.28.2)\r\n", + "Requirement already satisfied: pyyaml>=5.1 in /opt/conda/lib/python3.7/site-packages (from huggingface-hub>=0.4.0->sentence-transformers) (6.0)\r\n", + "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /opt/conda/lib/python3.7/site-packages (from transformers<5.0.0,>=4.6.0->sentence-transformers) (0.13.2)\r\n", + "Requirement already satisfied: regex!=2019.12.17 in /opt/conda/lib/python3.7/site-packages (from transformers<5.0.0,>=4.6.0->sentence-transformers) (2021.11.10)\r\n", + "Requirement already satisfied: six in /opt/conda/lib/python3.7/site-packages (from nltk->sentence-transformers) (1.16.0)\r\n", + "Requirement already satisfied: joblib>=0.11 in /opt/conda/lib/python3.7/site-packages (from scikit-learn->sentence-transformers) (1.2.0)\r\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.7/site-packages (from scikit-learn->sentence-transformers) (3.1.0)\r\n", + "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /opt/conda/lib/python3.7/site-packages (from torchvision->sentence-transformers) (9.4.0)\r\n", + "Requirement already satisfied: zipp>=0.5 in /opt/conda/lib/python3.7/site-packages (from importlib-metadata->huggingface-hub>=0.4.0->sentence-transformers) (3.11.0)\r\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /opt/conda/lib/python3.7/site-packages (from requests->huggingface-hub>=0.4.0->sentence-transformers) (2.1.1)\r\n", + "Requirement already satisfied: certifi>=2017.4.17 in /opt/conda/lib/python3.7/site-packages (from requests->huggingface-hub>=0.4.0->sentence-transformers) (2022.12.7)\r\n", + "Requirement already satisfied: idna<4,>=2.5 in /opt/conda/lib/python3.7/site-packages (from requests->huggingface-hub>=0.4.0->sentence-transformers) (3.4)\r\n", + "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /opt/conda/lib/python3.7/site-packages (from requests->huggingface-hub>=0.4.0->sentence-transformers) (1.26.14)\r\n", + "Building wheels for collected packages: sentence-transformers\r\n", + " Building wheel for sentence-transformers (setup.py) ... \u001b[?25l-\b \b\\\b \bdone\r\n", + "\u001b[?25h Created wheel for sentence-transformers: filename=sentence_transformers-2.2.2-py3-none-any.whl size=125938 sha256=9371fe0d23f9159127c11d05844caddaf0e7e7a0705a610d7054d2a05288ccec\r\n", + " Stored in directory: /root/.cache/pip/wheels/83/71/2b/40d17d21937fed496fb99145227eca8f20b4891240ff60c86f\r\n", + "Successfully built sentence-transformers\r\n", + "Installing collected packages: sentence-transformers\r\n", + "Successfully installed sentence-transformers-2.2.2\r\n", + "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\r\n", + "\u001b[0m" + ] + } + ], + "source": [ + "!pip install sentence-transformers" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "d7495aea", + "metadata": { + "_cell_guid": "b1076dfc-b9ad-4769-8c92-a6c4dae69d19", + "_uuid": "8f2839f25d086af736a60e9eeb907d3b93b6e0e5", + "execution": { + "iopub.execute_input": "2023-04-16T20:10:15.519547Z", + "iopub.status.busy": "2023-04-16T20:10:15.519198Z", + "iopub.status.idle": "2023-04-16T20:10:22.872720Z", + "shell.execute_reply": "2023-04-16T20:10:22.871550Z" + }, + "papermill": { + "duration": 7.361118, + "end_time": "2023-04-16T20:10:22.875455", + "exception": false, + "start_time": "2023-04-16T20:10:15.514337", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import datasets\n", + "from datasets import Dataset\n", + "import numpy as np\n", + "import json\n", + "import os\n", + "import torch\n", + "import sentence_transformers\n", + "from tqdm.notebook import tqdm as tqdm\n", + "from sentence_transformers import SentenceTransformer, InputExample, losses, evaluation\n", + "from torch.utils.data import DataLoader\n", + "\n", + "\n", + "np.random.seed(42)\n", + "torch.manual_seed(42)\n", + "\n", + "INPUT_PATH = '/kaggle/input/ysda-ml-02-05-process-json/articles.hf'\n", + "OUTPUT_PATH = '/kaggle/working/model.pt'" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "de7086f7", + "metadata": { + "execution": { + "iopub.execute_input": "2023-04-16T20:10:22.884000Z", + "iopub.status.busy": "2023-04-16T20:10:22.883145Z", + "iopub.status.idle": "2023-04-16T20:10:51.737079Z", + "shell.execute_reply": "2023-04-16T20:10:51.736023Z" + }, + "papermill": { + "duration": 28.861196, + "end_time": "2023-04-16T20:10:51.739985", + "exception": false, + "start_time": "2023-04-16T20:10:22.878789", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "articles = Dataset.load_from_disk(INPUT_PATH).to_dict()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "e88a0172", + "metadata": { + "execution": { + "iopub.execute_input": "2023-04-16T20:10:51.748231Z", + "iopub.status.busy": "2023-04-16T20:10:51.747906Z", + "iopub.status.idle": "2023-04-16T20:10:51.753405Z", + "shell.execute_reply": "2023-04-16T20:10:51.752319Z" + }, + "papermill": { + "duration": 0.01218, + "end_time": "2023-04-16T20:10:51.755667", + "exception": false, + "start_time": "2023-04-16T20:10:51.743487", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "base_model_name = 'sentence-transformers/all-MiniLM-L6-v2'\n", + "new_model_name = 'eremeev-d/all-MiniLM-L6-v2-arxiv-fine-tuned'\n", + "epochs = 1\n", + "batch_size = 80\n", + "train_positive_samples_size = 10**5\n", + "train_negative_samples_size = 10*train_positive_samples_size\n", + "eval_positive_samples_size = 10**3\n", + "eval_negative_samples_size = eval_positive_samples_size\n", + "evaluation_steps = train_positive_samples_size // (batch_size * 5)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "d9090048", + "metadata": { + "execution": { + "iopub.execute_input": "2023-04-16T20:10:51.763804Z", + "iopub.status.busy": "2023-04-16T20:10:51.763449Z", + "iopub.status.idle": "2023-04-16T20:11:05.699250Z", + "shell.execute_reply": "2023-04-16T20:11:05.698050Z" + }, + "papermill": { + "duration": 13.944232, + "end_time": "2023-04-16T20:11:05.703302", + "exception": false, + "start_time": "2023-04-16T20:10:51.759070", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "855aa21d23e64ac8a30cb9e1874ed817", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + " 0%| | 0/100000 [00:00