{ "cells": [ { "cell_type": "markdown", "source": [ "# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n", "\n", "DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States." ], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "%pip install accelerate -U" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": "Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\nRequirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\nRequirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\nRequirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\nRequirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\nRequirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\nRequirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\nRequirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\nNote: you may need to restart the kernel to use updated packages.\n" } ], "execution_count": 1, "metadata": { "jupyter": { "source_hidden": false, "outputs_hidden": false }, "nteract": { "transient": { "deleting": false } } } }, { "cell_type": "code", "source": [ "%pip install transformers datasets shap watermark wandb" ], "outputs": [ { "output_type": "stream", "name": 
"stderr", "text": "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" }, { "output_type": "stream", "name": "stdout", "text": "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\nRequirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\nRequirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\nRequirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\nCollecting wandb\n Using cached wandb-0.16.2-py3-none-any.whl (2.2 MB)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\nRequirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\nRequirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\nRequirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\nRequirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\nRequirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\nRequirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\nRequirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\nRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\nRequirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\nRequirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\nRequirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\nRequirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\nRequirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\nRequirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from 
shap) (1.10.1)\nRequirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\nRequirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\nRequirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\nRequirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\nRequirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\nRequirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\nRequirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\nCollecting sentry-sdk>=1.0.0\n Using cached sentry_sdk-1.39.2-py2.py3-none-any.whl (254 kB)\nRequirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\nCollecting docker-pycreds>=0.4.0\n Using cached docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\nRequirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\nCollecting setproctitle\n Using cached setproctitle-1.3.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (31 kB)\nCollecting appdirs>=1.4.3\n Using cached appdirs-1.4.4-py2.py3-none-any.whl (9.6 kB)\nRequirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\nRequirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\nRequirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\nRequirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\nRequirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\nRequirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\nRequirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\nRequirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\nRequirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\nRequirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\nRequirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\nRequirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\nRequirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages 
(from ipython>=6.0->watermark) (0.2.0)\nRequirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\nRequirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\nRequirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\nRequirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\nRequirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\nRequirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\nRequirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\nRequirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\nRequirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\nRequirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\nRequirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: python-dateutil>=2.8.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2.8.2)\nRequirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\nRequirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\nRequirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\nRequirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\nRequirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\nRequirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from 
prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\nRequirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\nRequirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\nRequirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\nInstalling collected packages: appdirs, setproctitle, sentry-sdk, docker-pycreds, wandb\nSuccessfully installed appdirs-1.4.4 docker-pycreds-0.4.0 sentry-sdk-1.39.2 setproctitle-1.3.3 wandb-0.16.2\nNote: you may need to restart the kernel to use updated packages.\n" } ], "execution_count": 17, "metadata": { "jupyter": { "source_hidden": false, "outputs_hidden": false }, "nteract": { "transient": { "deleting": false } } } }, { "cell_type": "code", "source": [ "import pandas as pd\n", "import numpy as np\n", "import torch\n", "import os\n", "from typing import List\n", "from sklearn.metrics import f1_score, accuracy_score, classification_report\n", "from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, pipeline\n", "from datasets import load_dataset\n", "import shap\n", "\n", "%load_ext watermark" ], "outputs": [ { "output_type": "stream", "name": "stderr", "text": "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n from .autonotebook import tqdm as notebook_tqdm\n2024-01-28 02:27:28.730200: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\nTo enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n2024-01-28 02:27:29.708865: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n2024-01-28 02:27:29.708983: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n2024-01-28 02:27:29.708996: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. 
If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n" } ], "execution_count": 3, "metadata": { "datalore": { "node_id": "caZjjFP0OyQNMVgZDiwswE", "type": "CODE", "hide_input_from_viewers": false, "hide_output_from_viewers": false, "report_properties": { "rowId": "un8W7ez7ZwoGb5Co6nydEV" } }, "gather": { "logged": 1706408851775 } } }, { "cell_type": "code", "source": [ "device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "\n", "SEED: int = 42\n", "\n", "BATCH_SIZE: int = 8\n", "EPOCHS: int = 1\n", "model_ckpt: str = \"distilbert-base-uncased\"\n", "\n", "CLASS_NAMES: List[str] = [\"DIED\",\n", " \"ER_VISIT\",\n", " \"HOSPITAL\",\n", " \"OFC_VISIT\",\n", " \"X_STAY\",\n", " \"DISABLE\",\n", " \"D_PRESENTED\"]\n", "\n", "# WandB configuration\n", "os.environ[\"WANDB_PROJECT\"] = \"DAEDRA model training\" # name your W&B project\n", "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints" ], "outputs": [], "execution_count": 4, "metadata": { "collapsed": false, "gather": { "logged": 1706408852045 } } }, { "cell_type": "code", "source": [ "%watermark --iversion" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": "re : 2.2.1\nnumpy : 1.23.5\nlogging: 0.5.1.2\npandas : 2.0.2\ntorch : 1.12.0\nshap : 0.44.1\n\n" } ], "execution_count": 5, "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "!nvidia-smi" ], "outputs": [ { "output_type": "stream", "name": "stdout", "text": "Sun Jan 28 02:27:31 2024 \r\n+---------------------------------------------------------------------------------------+\r\n| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\r\n|-----------------------------------------+----------------------+----------------------+\r\n| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\r\n| | | MIG M. 
|\r\n|=========================================+======================+======================|\r\n| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\r\n| N/A 28C P0 37W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\r\n| N/A 27C P0 36W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n \r\n+---------------------------------------------------------------------------------------+\r\n| Processes: |\r\n| GPU GI CI PID Type Process name GPU Memory |\r\n| ID ID Usage |\r\n|=======================================================================================|\r\n| No running processes found |\r\n+---------------------------------------------------------------------------------------+\r\n" } ], "execution_count": 6, "metadata": { "datalore": { "node_id": "UU2oOJhwbIualogG1YyCMd", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true } } }, { "attachments": {}, "cell_type": "markdown", "source": [ "## Loading the data set" ], "metadata": { "datalore": { "node_id": "t45KHugmcPVaO0nuk8tGJ9", "type": "MD", "hide_input_from_viewers": false, "hide_output_from_viewers": false, "report_properties": { "rowId": "40nN9Hvgi1clHNV5RAemI5" } } } }, { "cell_type": "code", "source": [ "dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")" ], "outputs": [], "execution_count": 7, "metadata": { "collapsed": false, "gather": { "logged": 1706408853264 } } }, { "cell_type": "markdown", "source": [ "### Tokenisation and encoding" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)" ], "outputs": [], "execution_count": 8, "metadata": { "datalore": { "node_id": "I7n646PIscsUZRoHu6m7zm", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706408853475 } } }, { "cell_type": "code", "source": [ "def tokenize_and_encode(examples):\n", " return tokenizer(examples[\"text\"], truncation=True)" ], "outputs": [], "execution_count": 9, "metadata": { "datalore": { "node_id": "QBLOSI0yVIslV7v7qX9ZC3", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706408853684 } } }, { "cell_type": "code", "source": [ "cols = dataset[\"train\"].column_names\n", "cols.remove(\"labels\")\n", "ds_enc = dataset.map(tokenize_and_encode, batched=True, remove_columns=cols)" ], "outputs": [ { "output_type": "stream", "name": "stderr", "text": "Map: 100%|██████████| 15786/15786 [00:01<00:00, 10990.82 examples/s]\n" } ], "execution_count": 10, "metadata": { "datalore": { "node_id": "slHeNysZOX9uWS9PB7jFDb", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706408854738 } } }, { "cell_type": "markdown", "source": [ "### Training" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "class MultiLabelTrainer(Trainer):\n", " def compute_loss(self, model, inputs, return_outputs=False):\n", " labels = inputs.pop(\"labels\")\n", " outputs = model(**inputs)\n", " logits = outputs.logits\n", " loss_fct = torch.nn.BCEWithLogitsLoss()\n", " loss = loss_fct(logits.view(-1, self.model.config.num_labels),\n", " labels.float().view(-1, self.model.config.num_labels))\n", " return (loss, outputs) if 
return_outputs else loss" ], "outputs": [], "execution_count": 11, "metadata": { "datalore": { "node_id": "itXWkbDw9sqbkMuDP84QoT", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706408854925 } } }, { "cell_type": "code", "source": [ "model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, num_labels=len(CLASS_NAMES)).to(\"cuda\")" ], "outputs": [ { "output_type": "stream", "name": "stderr", "text": "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" } ], "execution_count": 12, "metadata": { "datalore": { "node_id": "ZQU7aW6TV45VmhHOQRzcnF", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706408857008 } } }, { "cell_type": "code", "source": [ "def accuracy_threshold(y_pred, y_true, threshold=.5, sigmoid=True):\n", " y_pred = torch.from_numpy(y_pred)\n", " y_true = torch.from_numpy(y_true)\n", "\n", " if sigmoid:\n", " y_pred = y_pred.sigmoid()\n", "\n", " return ((y_pred > threshold) == y_true.bool()).float().mean().item()" ], "outputs": [], "execution_count": 13, "metadata": { "datalore": { "node_id": "swhgyyyxoGL8HjnXJtMuSW", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706408857297 } } }, { "cell_type": "code", "source": [ "def compute_metrics(eval_pred):\n", " predictions, labels = eval_pred\n", " return {'accuracy_thresh': accuracy_threshold(predictions, labels)}" ], "outputs": [], "execution_count": 14, "metadata": { "datalore": { "node_id": "1Uq3HtkaBxtHNAnSwit5cI", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706408857499 } } }, { "cell_type": "code", "source": [ "args = TrainingArguments(\n", " output_dir=\"vaers\",\n", " evaluation_strategy=\"epoch\",\n", " learning_rate=2e-5,\n", " per_device_train_batch_size=BATCH_SIZE,\n", " per_device_eval_batch_size=BATCH_SIZE,\n", " num_train_epochs=EPOCHS,\n", " weight_decay=.01,\n", " report_to=[\"wandb\"]\n", ")" ], "outputs": [], "execution_count": 15, "metadata": { "datalore": { "node_id": "1iPZOTKPwSkTgX5dORqT89", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706408857680 } } }, { "cell_type": "code", "source": [ "multi_label_trainer = MultiLabelTrainer(\n", " model, \n", " args, \n", " train_dataset=ds_enc[\"train\"], \n", " eval_dataset=ds_enc[\"test\"], \n", " compute_metrics=compute_metrics, \n", " tokenizer=tokenizer\n", ")" ], "outputs": [ { "output_type": "stream", "name": "stderr", "text": "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. 
Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" } ], "execution_count": 18, "metadata": { "datalore": { "node_id": "bnRkNvRYltLun6gCEgL7v0", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706408895305 } } }, { "cell_type": "code", "source": [ "multi_label_trainer.evaluate()" ], "outputs": [ { "output_type": "display_data", "data": { "text/plain": "", "text/html": "\n
[987/987 21:41]
\n " }, "metadata": {} }, { "output_type": "stream", "name": "stderr", "text": "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mchrisvoncsefalvay\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n" }, { "output_type": "display_data", "data": { "text/plain": "", "text/html": "Tracking run with wandb version 0.16.2" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": "", "text/html": "Run data is saved locally in /mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_022947-hh1sxw9i" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": "", "text/html": "Syncing run icy-firebrand-1 to Weights & Biases (docs)
" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": "", "text/html": " View project at https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training" }, "metadata": {} }, { "output_type": "display_data", "data": { "text/plain": "", "text/html": " View run at https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/hh1sxw9i" }, "metadata": {} }, { "output_type": "execute_result", "execution_count": 19, "data": { "text/plain": "{'eval_loss': 0.7153111100196838,\n 'eval_accuracy_thresh': 0.2938227355480194,\n 'eval_runtime': 82.3613,\n 'eval_samples_per_second': 191.668,\n 'eval_steps_per_second': 11.984}" }, "metadata": {} } ], "execution_count": 19, "metadata": { "datalore": { "node_id": "LO54PlDkWQdFrzV25FvduB", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706408991752 } } }, { "cell_type": "code", "source": [ "multi_label_trainer.train()" ], "outputs": [ { "output_type": "display_data", "data": { "text/plain": "", "text/html": "\n
[4605/4605 20:25, Epoch 1/1]\nEpoch | Training Loss | Validation Loss | Accuracy Thresh\n1 | 0.086700 | 0.093388 | 0.962897\n
" }, "metadata": {} }, { "output_type": "stream", "name": "stderr", "text": "Checkpoint destination directory vaers/checkpoint-500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-500)... Done. 15.9s\nCheckpoint destination directory vaers/checkpoint-1000 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-1000)... Done. 12.5s\nCheckpoint destination directory vaers/checkpoint-1500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-1500)... Done. 21.9s\nCheckpoint destination directory vaers/checkpoint-2000 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-2000)... Done. 13.8s\nCheckpoint destination directory vaers/checkpoint-2500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-2500)... Done. 15.7s\nCheckpoint destination directory vaers/checkpoint-3000 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-3000)... Done. 21.7s\nCheckpoint destination directory vaers/checkpoint-3500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-3500)... Done. 10.6s\nCheckpoint destination directory vaers/checkpoint-4000 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-4000)... Done. 15.0s\nCheckpoint destination directory vaers/checkpoint-4500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-4500)... Done. 16.7s\n" }, { "output_type": "execute_result", "execution_count": 21, "data": { "text/plain": "TrainOutput(global_step=4605, training_loss=0.09062977189220382, metrics={'train_runtime': 1223.2444, 'train_samples_per_second': 60.223, 'train_steps_per_second': 3.765, 'total_flos': 9346797199425174.0, 'train_loss': 0.09062977189220382, 'epoch': 1.0})" }, "metadata": {} } ], "execution_count": 21, "metadata": { "datalore": { "node_id": "hf0Ei1QXEYDmBv1VNLZ4Zw", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706411445752 } } }, { "cell_type": "markdown", "source": [ "### Evaluation" ], "metadata": { "collapsed": false } }, { "cell_type": "markdown", "source": [ "We instantiate a classifier `pipeline` and push it to CUDA." 
], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "classifier = pipeline(\"text-classification\", \n", " model, \n", " tokenizer=tokenizer, \n", " device=\"cuda:0\")" ], "outputs": [], "execution_count": 24, "metadata": { "datalore": { "node_id": "kHoUdBeqcyVXDSGv54C4aE", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706411459928 } } }, { "cell_type": "markdown", "source": [ "We use the same tokenizer used for training to tokenize/encode the validation set." ], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "test_encodings = tokenizer.batch_encode_plus(dataset[\"val\"][\"text\"], \n", " max_length=255, \n", " pad_to_max_length=True, \n", " return_token_type_ids=True, \n", " truncation=True)" ], "outputs": [ { "output_type": "error", "ename": "KeyError", "evalue": "'validate'", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", "Cell \u001b[0;32mIn[25], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m test_encodings \u001b[38;5;241m=\u001b[39m tokenizer\u001b[38;5;241m.\u001b[39mbatch_encode_plus(\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mvalidate\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtext\u001b[39m\u001b[38;5;124m\"\u001b[39m], \n\u001b[1;32m 2\u001b[0m max_length\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m255\u001b[39m, \n\u001b[1;32m 3\u001b[0m pad_to_max_length\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, \n\u001b[1;32m 4\u001b[0m return_token_type_ids\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, \n\u001b[1;32m 5\u001b[0m truncation\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n", "File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/datasets/dataset_dict.py:74\u001b[0m, in \u001b[0;36mDatasetDict.__getitem__\u001b[0;34m(self, k)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, k) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Dataset:\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(k, (\u001b[38;5;28mstr\u001b[39m, NamedSplit)) \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m---> 74\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__getitem__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mk\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 75\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 76\u001b[0m available_suggested_splits \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 77\u001b[0m split \u001b[38;5;28;01mfor\u001b[39;00m split \u001b[38;5;129;01min\u001b[39;00m (Split\u001b[38;5;241m.\u001b[39mTRAIN, Split\u001b[38;5;241m.\u001b[39mTEST, Split\u001b[38;5;241m.\u001b[39mVALIDATION) \u001b[38;5;28;01mif\u001b[39;00m split \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\n\u001b[1;32m 78\u001b[0m ]\n", "\u001b[0;31mKeyError\u001b[0m: 'validate'" ] } ], "execution_count": 25, "metadata": { "datalore": { "node_id": "Dr5WCWA6jL51NR1fSrQu6Z", 
"type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706411465538 } } }, { "cell_type": "markdown", "source": [ "Once we've made the data loadable by putting it into a `DataLoader`, we " ], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "test_data = torch.utils.data.TensorDataset(torch.tensor(test_encodings['input_ids']), \n", " torch.tensor(test_encodings['attention_mask']), \n", " torch.tensor(ds_enc[\"validate\"][\"labels\"]), \n", " torch.tensor(test_encodings['token_type_ids']))\n", "test_dataloader = torch.utils.data.DataLoader(test_data, \n", " sampler=torch.utils.data.SequentialSampler(test_data), \n", " batch_size=BATCH_SIZE)" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "MWfGq2tTkJNzFiDoUPq2X7", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706411446707 } } }, { "cell_type": "code", "source": [ "model.eval()\n", "\n", "logit_preds, true_labels, pred_labels, tokenized_texts = [], [], [], []\n", "\n", "for i, batch in enumerate(test_dataloader):\n", " batch = tuple(t.to(device) for t in batch)\n", " # Unpack the inputs from our dataloader\n", " b_input_ids, b_input_mask, b_labels, b_token_types = batch\n", " \n", " with torch.no_grad():\n", " outs = model(b_input_ids, attention_mask=b_input_mask)\n", " b_logit_pred = outs[0]\n", " pred_label = torch.sigmoid(b_logit_pred)\n", "\n", " b_logit_pred = b_logit_pred.detach().cpu().numpy()\n", " pred_label = pred_label.to('cpu').numpy()\n", " b_labels = b_labels.to('cpu').numpy()\n", "\n", " tokenized_texts.append(b_input_ids)\n", " logit_preds.append(b_logit_pred)\n", " true_labels.append(b_labels)\n", " pred_labels.append(pred_label)\n", "\n", "# Flatten outputs\n", "tokenized_texts = [item for sublist in tokenized_texts for item in sublist]\n", "pred_labels = [item for sublist in pred_labels for item in sublist]\n", "true_labels = [item for sublist in true_labels for item in sublist]\n", "\n", "# Converting flattened binary values to boolean values\n", "true_bools = [tl == 1 for tl in true_labels]\n", "pred_bools = [pl > 0.50 for pl in pred_labels] " ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "1SJCSrQTRCexFCNCIyRrzL", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706411446723 } } }, { "cell_type": "markdown", "source": [ "We create a classification report:" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "print('Test F1 Accuracy: ', f1_score(true_bools, pred_bools, average='micro'))\n", "print('Test Flat Accuracy: ', accuracy_score(true_bools, pred_bools), '\\n')\n", "clf_report = classification_report(true_bools, pred_bools, target_names=CLASS_NAMES)\n", "print(clf_report)" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "eBprrgF086mznPbPVBpOLS", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706411446746 } } }, { "cell_type": "markdown", "source": [ "Finally, we render a 'head to head' comparison table that maps each text prediction to actual and predicted labels." 
], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "# Creating a map of class names from class numbers\n", "idx2label = dict(zip(range(len(CLASS_NAMES)), CLASS_NAMES))" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "yELHY0IEwMlMw3x6e7hoD1", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706411446758 } } }, { "cell_type": "code", "source": [ "true_label_idxs, pred_label_idxs = [], []\n", "\n", "for vals in true_bools:\n", " true_label_idxs.append(np.where(vals)[0].flatten().tolist())\n", "for vals in pred_bools:\n", " pred_label_idxs.append(np.where(vals)[0].flatten().tolist())" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "jH0S35dDteUch01sa6me6e", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706411446771 } } }, { "cell_type": "code", "source": [ "true_label_texts, pred_label_texts = [], []\n", "\n", "for vals in true_label_idxs:\n", " if vals:\n", " true_label_texts.append([idx2label[val] for val in vals])\n", " else:\n", " true_label_texts.append(vals)\n", "\n", "for vals in pred_label_idxs:\n", " if vals:\n", " pred_label_texts.append([idx2label[val] for val in vals])\n", " else:\n", " pred_label_texts.append(vals)" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "h4vHL8XdGpayZ6xLGJUF6F", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706411446785 } } }, { "cell_type": "code", "source": [ "symptom_texts = [tokenizer.decode(text,\n", " skip_special_tokens=True,\n", " clean_up_tokenization_spaces=False) for text in tokenized_texts]" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "SxUmVHfQISEeptg1SawOmB", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706411446805 } } }, { "cell_type": "code", "source": [ "comparisons_df = pd.DataFrame({'symptom_text': symptom_texts, \n", " 'true_labels': true_label_texts, \n", " 'pred_labels':pred_label_texts})\n", "comparisons_df.to_csv('comparisons.csv')\n", "comparisons_df" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "BxFNigNGRLTOqraI55BPSH", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706411446818 } } }, { "cell_type": "markdown", "source": [ "### Shapley analysis" ], "metadata": { "collapsed": false } }, { "cell_type": "code", "source": [ "explainer = shap.Explainer(classifier, output_names=CLASS_NAMES)" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "OpdZcoenX2HwzLdai7K5UA", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706411446829 } } }, { "cell_type": "code", "source": [ "shap_values = explainer(dataset[\"validate\"][\"text\"][1:2])" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "FvbCMfIDlcf16YSvb8wNQv", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 1706411446839 } } }, { "cell_type": "code", "source": [ "shap.plots.text(shap_values)" ], "outputs": [], "execution_count": null, "metadata": { "datalore": { "node_id": "TSxvakWLPCpjVMWi9ZdEbd", "type": "CODE", "hide_input_from_viewers": true, "hide_output_from_viewers": true }, "gather": { "logged": 
1706411446848 } } }, { "cell_type": "code", "source": [], "outputs": [], "execution_count": null, "metadata": { "jupyter": { "source_hidden": false, "outputs_hidden": false }, "nteract": { "transient": { "deleting": false } } } } ], "metadata": { "kernelspec": { "name": "python3", "language": "python", "display_name": "Python 3 (ipykernel)" }, "datalore": { "computation_mode": "JUPYTER", "package_manager": "pip", "base_environment": "default", "packages": [ { "name": "datasets", "version": "2.16.1", "source": "PIP" }, { "name": "torch", "version": "2.1.2", "source": "PIP" }, { "name": "accelerate", "version": "0.26.1", "source": "PIP" } ], "report_row_ids": [ "un8W7ez7ZwoGb5Co6nydEV", "40nN9Hvgi1clHNV5RAemI5", "TgRD90H5NSPpKS41OeXI1w", "ZOm5BfUs3h1EGLaUkBGeEB", "kOP0CZWNSk6vqE3wkPp7Vc", "W4PWcOu2O2pRaZyoE2W80h", "RolbOnQLIftk0vy9mIcz5M", "8OPhUgbaNJmOdiq5D3a6vK", "5Qrt3jSvSrpK6Ne1hS6shL", "hTq7nFUrovN5Ao4u6dIYWZ", "I8WNZLpJ1DVP2wiCW7YBIB", "SawhU3I9BewSE1XBPstpNJ", "80EtLEl2FIE4FqbWnUD3nT" ], "version": 3 }, "microsoft": { "ms_spell_check": { "ms_spell_check_language": "en" } }, "language_info": { "name": "python", "version": "3.8.5", "mimetype": "text/x-python", "codemirror_mode": { "name": "ipython", "version": 3 }, "pygments_lexer": "ipython3", "nbconvert_exporter": "python", "file_extension": ".py" }, "nteract": { "version": "nteract-front-end@1.0.0" } }, "nbformat": 4, "nbformat_minor": 4 }