{ "cells": [ { "cell_type": "markdown", "id": "53a990e3-0d47-4e66-b928-f40d67f06584", "metadata": {}, "source": [ "# Setup" ] }, { "cell_type": "markdown", "id": "51fb0d43-c12b-4892-95d2-074bf5de0ce2", "metadata": {}, "source": [ "## Install addition packages" ] }, { "cell_type": "code", "execution_count": 1, "id": "9cf48779-454b-4b1d-b78f-531a1b207276", "metadata": { "tags": [] }, "outputs": [], "source": [ "import os\n", "\n", "# The Google Cloud Notebook product has specific requirements\n", "IS_GOOGLE_CLOUD_NOTEBOOK = os.path.exists(\"/opt/deeplearning/metadata/env_version\")\n", "\n", "# Google Cloud Notebook requires dependencies to be installed with '--user'\n", "USER_FLAG = \"\"\n", "if IS_GOOGLE_CLOUD_NOTEBOOK:\n", " USER_FLAG = \"--user\"" ] }, { "cell_type": "code", "execution_count": 2, "id": "d2a3556a-ebf1-49c7-9d2c-63e30ca45f73", "metadata": { "tags": [] }, "outputs": [], "source": [ "%%capture\n", "!pip -q install {USER_FLAG} --upgrade transformers\n", "!pip -q install {USER_FLAG} --upgrade datasets\n", "!pip -q install {USER_FLAG} --upgrade tqdm\n", "!pip -q install {USER_FLAG} --upgrade cloudml-hypertune" ] }, { "cell_type": "code", "execution_count": 3, "id": "fcc3f1f6-36d3-4056-ad29-b69c57bb0bac", "metadata": { "tags": [] }, "outputs": [], "source": [ "%%capture\n", "!pip -q install {USER_FLAG} --upgrade google-cloud-aiplatform" ] }, { "cell_type": "code", "execution_count": 4, "id": "2214d165-356d-47f1-a4ee-4f6c50027e96", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Automatically restart kernel after installs\n", "import os\n", "\n", "if not os.getenv(\"IS_TESTING\"):\n", " # Automatically restart kernel after installs\n", " import IPython\n", "\n", " app = IPython.Application.instance()\n", " app.kernel.do_shutdown(True)" ] }, { "cell_type": "code", "execution_count": 1, "id": "e8817443-c80e-475b-b54e-dd834c040b12", "metadata": {}, "outputs": [], "source": [ "%%capture\n", "!pip install 
git+https://github.com/huggingface/transformers.git datasets pandas torch\n", "!pip install transformers[torch]\n", "!pip install accelerate -U" ] }, { "cell_type": "markdown", "id": "21cc7690-95bf-4452-abef-46cd318ccfb5", "metadata": {}, "source": [ "## Set Project ID" ] }, { "cell_type": "code", "execution_count": 2, "id": "30b78533-ff39-4c92-a365-f2e05ddb642f", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Project ID: ikame-gem-ai-research\n" ] } ], "source": [ "PROJECT_ID = \"iKame-gem-ai-research\" # <---CHANGE THIS TO YOUR PROJECT\n", "\n", "import os\n", "\n", "# Get your Google Cloud project ID using google.auth\n", "if not os.getenv(\"IS_TESTING\"):\n", " import google.auth\n", "\n", " _, PROJECT_ID = google.auth.default()\n", " print(\"Project ID: \", PROJECT_ID)\n", "\n", "# validate PROJECT_ID\n", "if PROJECT_ID == \"\" or PROJECT_ID is None or PROJECT_ID == \"iKame-gem-ai-research\":\n", " print(\n", " f\"Please set your project id before proceeding to next step. 
Currently it's set as {PROJECT_ID}\"\n", " )" ] }, { "cell_type": "code", "execution_count": 3, "id": "5c4631f5-c8ba-43e9-a623-08cb2cb3a51a", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "TIMESTAMP = 20240108040502\n" ] } ], "source": [ "from datetime import datetime\n", "\n", "\n", "def get_timestamp():\n", " return datetime.now().strftime(\"%Y%m%d%H%M%S\")\n", "\n", "\n", "TIMESTAMP = get_timestamp()\n", "print(f\"TIMESTAMP = {TIMESTAMP}\")" ] }, { "cell_type": "markdown", "id": "494d8009-7f9a-45d8-ba7c-3e3205d1c96b", "metadata": {}, "source": [ "## Create Cloud Storage bucket" ] }, { "cell_type": "code", "execution_count": 4, "id": "303136a0-6334-4889-b43b-9f171a934311", "metadata": { "tags": [] }, "outputs": [], "source": [ "BUCKET_NAME = \"gs://iKame-gem-ai-research\" # <---CHANGE THIS TO YOUR BUCKET\n", "REGION = \"us-central1\" # @param {type:\"string\"}" ] }, { "cell_type": "code", "execution_count": 5, "id": "014c6208-0b1a-4da8-888b-19c02a112474", "metadata": { "tags": [] }, "outputs": [], "source": [ "if BUCKET_NAME == \"\" or BUCKET_NAME is None or BUCKET_NAME == \"gs://iKame-gem-ai-research\":\n", " BUCKET_NAME = f\"gs://{PROJECT_ID}-bucket-review\"" ] }, { "cell_type": "code", "execution_count": 6, "id": "a52a28fa-591e-487c-bd53-8f770441ba63", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "PROJECT_ID = ikame-gem-ai-research\n", "BUCKET_NAME = gs://ikame-gem-ai-research-bucket-review\n", "REGION = us-central1\n" ] } ], "source": [ "print(f\"PROJECT_ID = {PROJECT_ID}\")\n", "print(f\"BUCKET_NAME = {BUCKET_NAME}\")\n", "print(f\"REGION = {REGION}\")" ] }, { "cell_type": "code", "execution_count": 7, "id": "24c35eb2-7619-4958-a04a-79b62788f257", "metadata": { "tags": [] }, "outputs": [], "source": [ "# ! 
gsutil mb -l $REGION $BUCKET_NAME" ] }, { "cell_type": "code", "execution_count": 8, "id": "6f2ee0a0-3cff-47cb-9379-6f6e75fef9d5", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " 3078 2024-01-05T01:42:25Z gs://ikame-gem-ai-research-bucket-review/batch_examples.csv#1704418945853255 metageneration=1\n", " gs://ikame-gem-ai-research-bucket-review/pipeline_root/\n", "TOTAL: 1 objects, 3078 bytes (3.01 KiB)\n" ] } ], "source": [ "! gsutil ls -al $BUCKET_NAME #validate access to your Cloud Storage bucket" ] }, { "cell_type": "markdown", "id": "da865a4c-5e29-465e-abf2-e443dae1b573", "metadata": {}, "source": [ "## Import libraries" ] }, { "cell_type": "code", "execution_count": 9, "id": "fedbebaf-516e-4f7d-8a70-c7dc31de02df", "metadata": { "tags": [] }, "outputs": [], "source": [ "import base64\n", "import json\n", "import os\n", "import random\n", "import sys\n", "\n", "import google.auth\n", "from google.cloud import aiplatform\n", "from google.cloud.aiplatform import gapic as aip\n", "from google.cloud.aiplatform import hyperparameter_tuning as hpt\n", "from google.protobuf.json_format import MessageToDict" ] }, { "cell_type": "code", "execution_count": 10, "id": "0cc75279-b7a9-47cc-81a4-f8729c7d57f8", "metadata": { "tags": [] }, "outputs": [], "source": [ "from IPython.display import HTML, display" ] }, { "cell_type": "code", "execution_count": 11, "id": "8856c9f3-270f-4dca-8a10-6bdee1af8bc0", "metadata": { "tags": [] }, "outputs": [], "source": [ "import datasets\n", "from datasets import Dataset, DatasetDict\n", "import numpy as np\n", "import pandas as pd\n", "import torch\n", "import transformers\n", "from datasets import ClassLabel, Sequence, load_dataset\n", "from transformers import (AutoModelForSequenceClassification, AutoTokenizer,BertForSequenceClassification,\n", " EvalPrediction, Trainer, TrainingArguments,PreTrainedModel,BertModel,\n", " default_data_collator)\n", "# Needed by the custom multi-label BERT head, which returns this output type when return_dict is set.\n", "from transformers.modeling_outputs import SequenceClassifierOutput" ] }, { "cell_type": "code", 
"execution_count": 12, "id": "bbecdaa8-3cd3-4e7b-939d-f959da9301d6", "metadata": { "tags": [] }, "outputs": [], "source": [ "from google.cloud import bigquery\n", "from google.cloud import storage\n", "\n", "client = bigquery.Client()\n", "storage_client = storage.Client()" ] }, { "cell_type": "code", "execution_count": 13, "id": "f693060f-c0ed-4ec3-bc66-17898f8ef854", "metadata": { "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Notebook runtime: GPU\n", "PyTorch version : 2.0.0+cu118\n", "Transformers version : 4.37.0.dev0\n", "Datasets version : 2.16.1\n" ] } ], "source": [ "print(f\"Notebook runtime: {'GPU' if torch.cuda.is_available() else 'CPU'}\")\n", "print(f\"PyTorch version : {torch.__version__}\")\n", "print(f\"Transformers version : {transformers.__version__}\")\n", "print(f\"Datasets version : {datasets.__version__}\")" ] }, { "cell_type": "code", "execution_count": 15, "id": "5637d9f0-d290-4107-974a-bfbda3b316b2", "metadata": { "tags": [] }, "outputs": [], "source": [ "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"" ] }, { "cell_type": "code", "execution_count": 14, "id": "3d114e96-31c2-4ed9-82d1-f2fab38f0944", "metadata": { "tags": [] }, "outputs": [], "source": [ "APP_NAME = \"aift-review-classificatio-multiple-label\"" ] }, { "cell_type": "code", "execution_count": null, "id": "173dcb77-9908-4af1-86bb-7811c9f580e9", "metadata": {}, "outputs": [], "source": [ "# NOTE: `!cd` runs in a throwaway subshell, so it does NOT change the kernel's working\n", "# directory; use `%cd <dir>` instead if a directory change is actually intended.\n", "!cd aift-model-review-multiple-label-classification" ] }, { "cell_type": "markdown", "id": "3f383051-501f-4f8c-8017-c989c5740041", "metadata": {}, "source": [ "# Training" ] }, { "cell_type": "markdown", "id": "db9715cc-0779-47a4-a0ed-82714b6668f6", "metadata": {}, "source": [ "## Preprocess data" ] }, { "cell_type": "code", "execution_count": 16, "id": "052ecc7b-c015-49a0-a359-85afbac10bbf", "metadata": { "tags": [] }, "outputs": [], "source": [ "model_ckpt = \"distilbert-base-uncased\"\n", "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n", "\n", 
"def tokenize_and_encode(examples):\n", " return tokenizer(examples[\"review\"], truncation=True)" ] }, { "cell_type": "code", "execution_count": 17, "id": "6f5faf02-ede8-4d48-b94a-1d4619c8e610", "metadata": { "tags": [] }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7a2415bdfd4a40fe80afe71e70d97976", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/556 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "3b1c36309d4e4e108e79578edc45ed56", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/140 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2b79b69e8457427781c8e6fc8ad54d82", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/556 [00:00, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "e1e4981003d04646944fa0ce8ae0dc73", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/140 [00:00, ? 
examples/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "sql = f\"\"\"\n", "SELECT * FROM `ikame-gem-ai-research.AIFT.reviews_multi_label_training`\n", "\"\"\"\n", "data = client.query(sql).to_dataframe()\n", "data= data.fillna('0')\n", "for i in data.columns:\n", " if i != 'review':\n", " data[i] = data[i].astype(int)\n", "\n", "data = Dataset.from_pandas(data).train_test_split(test_size=0.2,shuffle = True, seed=0)\n", "cols = data[\"train\"].column_names\n", "data = data.map(lambda x : {\"labels\": [x[c] for c in cols if c != \"review\"]})\n", "\n", "# Tokenize and encode\n", "dataset = data.map(tokenize_and_encode, batched=True, remove_columns=cols)" ] }, { "cell_type": "code", "execution_count": 18, "id": "f56a7de9-19a4-4cc8-996d-857c491cf633", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['ads', 'bugs', 'positive', 'negative', 'graphic', 'gameplay', 'request']" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "labels = [label for label in data['train'].features.keys() if label not in ['review','labels']]\n", "id2label = {idx:label for idx, label in enumerate(labels)}\n", "label2id = {label:idx for idx, label in enumerate(labels)}\n", "labels" ] }, { "cell_type": "code", "execution_count": 19, "id": "ad182dbc-c63d-49c9-b53c-9b63996d3746", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'labels': [0, 1, 0, 0, 0, 1, 0],\n", " 'input_ids': [101,\n", " 8795,\n", " 11100,\n", " 2024,\n", " 10599,\n", " 2030,\n", " 11829,\n", " 5999,\n", " 1010,\n", " 2437,\n", " 14967,\n", " 25198,\n", " 1012,\n", " 102],\n", " 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset[\"train\"][0]" ] }, { "cell_type": "markdown", "id": "02c2a7b2-58f1-4eac-ac61-5d54dbdc1184", "metadata": {}, "source": [ "## Fine-tuning" ] }, { "cell_type": "code", "execution_count": 20, 
"id": "9452f6f3-2b4b-4ee7-8c9f-3c42e04e396f", "metadata": { "tags": [] }, "outputs": [], "source": [ "class BertForMultilabelSequenceClassification(BertForSequenceClassification):\n", " def __init__(self, config):\n", " super().__init__(config)\n", "\n", " def forward(self,\n", " input_ids=None,\n", " attention_mask=None,\n", " token_type_ids=None,\n", " position_ids=None,\n", " head_mask=None,\n", " inputs_embeds=None,\n", " labels=None,\n", " output_attentions=None,\n", " output_hidden_states=None,\n", " return_dict=None):\n", " return_dict = return_dict if return_dict is not None else self.config.use_return_dict\n", "\n", " outputs = self.bert(input_ids,\n", " attention_mask=attention_mask,\n", " token_type_ids=token_type_ids,\n", " position_ids=position_ids,\n", " head_mask=head_mask,\n", " inputs_embeds=inputs_embeds,\n", " output_attentions=output_attentions,\n", " output_hidden_states=output_hidden_states,\n", " return_dict=return_dict)\n", "\n", " pooled_output = outputs[1]\n", " pooled_output = self.dropout(pooled_output)\n", " logits = self.classifier(pooled_output)\n", "\n", " loss = None\n", " if labels is not None:\n", " loss_fct = torch.nn.BCEWithLogitsLoss()\n", " loss = loss_fct(logits.view(-1, self.num_labels),\n", " labels.float().view(-1, self.num_labels))\n", "\n", " if not return_dict:\n", " output = (logits,) + outputs[2:]\n", " return ((loss,) + output) if loss is not None else output\n", "\n", " return SequenceClassifierOutput(loss=loss,\n", " logits=logits,\n", " hidden_states=outputs.hidden_states,\n", " attentions=outputs.attentions)" ] }, { "cell_type": "code", "execution_count": 21, "id": "76035010-b10a-4398-8a85-feaa19414ca4", "metadata": { "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "You are using a model of type distilbert to instantiate a model of type bert. 
This is not supported for all configurations of models and can yield errors.\n", "Some weights of BertForMultilabelSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['encoder.layer.11.attention.self.key.bias', 'encoder.layer.6.attention.output.LayerNorm.bias', 'encoder.layer.3.attention.output.LayerNorm.bias', 'encoder.layer.11.attention.self.query.weight', 'encoder.layer.6.attention.self.value.bias', 'encoder.layer.4.output.LayerNorm.bias', 'encoder.layer.4.attention.self.key.bias', 'encoder.layer.9.output.LayerNorm.weight', 'encoder.layer.11.attention.self.query.bias', 'encoder.layer.11.intermediate.dense.weight', 'encoder.layer.1.output.LayerNorm.bias', 'encoder.layer.4.output.LayerNorm.weight', 'classifier.weight', 'encoder.layer.8.output.dense.bias', 'encoder.layer.9.attention.self.key.bias', 'encoder.layer.5.attention.self.key.bias', 'encoder.layer.5.intermediate.dense.bias', 'encoder.layer.3.attention.output.LayerNorm.weight', 'encoder.layer.7.attention.output.dense.bias', 'encoder.layer.1.attention.output.LayerNorm.bias', 'encoder.layer.1.output.dense.weight', 'encoder.layer.6.attention.output.LayerNorm.weight', 'encoder.layer.11.output.LayerNorm.bias', 'embeddings.token_type_embeddings.weight', 'encoder.layer.3.intermediate.dense.weight', 'encoder.layer.4.attention.self.key.weight', 'encoder.layer.11.attention.output.LayerNorm.weight', 'encoder.layer.6.intermediate.dense.weight', 'encoder.layer.9.attention.self.value.weight', 'embeddings.position_embeddings.weight', 'encoder.layer.10.attention.self.query.bias', 'encoder.layer.0.attention.output.dense.weight', 'encoder.layer.10.attention.self.key.weight', 'encoder.layer.2.attention.output.dense.bias', 'encoder.layer.3.attention.self.key.weight', 'encoder.layer.7.output.LayerNorm.bias', 'encoder.layer.2.attention.output.dense.weight', 'encoder.layer.5.attention.output.dense.weight', 'encoder.layer.8.attention.output.LayerNorm.bias', 
'encoder.layer.2.intermediate.dense.bias', 'encoder.layer.11.intermediate.dense.bias', 'encoder.layer.4.intermediate.dense.weight', 'encoder.layer.6.output.dense.bias', 'encoder.layer.0.intermediate.dense.weight', 'encoder.layer.7.intermediate.dense.bias', 'encoder.layer.7.attention.self.value.bias', 'encoder.layer.6.attention.self.query.bias', 'encoder.layer.7.output.LayerNorm.weight', 'encoder.layer.3.attention.self.value.bias', 'encoder.layer.2.output.LayerNorm.weight', 'encoder.layer.10.intermediate.dense.bias', 'encoder.layer.2.attention.self.query.weight', 'encoder.layer.8.attention.output.dense.bias', 'encoder.layer.5.output.dense.bias', 'encoder.layer.9.attention.output.dense.bias', 'encoder.layer.9.attention.self.value.bias', 'encoder.layer.0.attention.output.LayerNorm.bias', 'encoder.layer.3.attention.output.dense.weight', 'encoder.layer.6.attention.self.key.bias', 'encoder.layer.1.attention.self.query.bias', 'encoder.layer.11.attention.self.value.weight', 'encoder.layer.10.intermediate.dense.weight', 'encoder.layer.5.attention.self.key.weight', 'encoder.layer.7.intermediate.dense.weight', 'encoder.layer.2.attention.self.key.bias', 'encoder.layer.7.output.dense.weight', 'encoder.layer.1.attention.output.LayerNorm.weight', 'encoder.layer.10.output.LayerNorm.bias', 'encoder.layer.5.output.LayerNorm.weight', 'encoder.layer.7.attention.output.dense.weight', 'encoder.layer.10.attention.output.LayerNorm.weight', 'encoder.layer.6.attention.output.dense.weight', 'encoder.layer.9.attention.self.query.weight', 'encoder.layer.10.attention.output.LayerNorm.bias', 'encoder.layer.0.output.LayerNorm.bias', 'encoder.layer.10.attention.output.dense.bias', 'encoder.layer.1.output.LayerNorm.weight', 'encoder.layer.5.output.dense.weight', 'encoder.layer.5.attention.self.query.weight', 'classifier.bias', 'encoder.layer.5.intermediate.dense.weight', 'encoder.layer.1.intermediate.dense.weight', 'encoder.layer.1.attention.output.dense.bias', 
'encoder.layer.3.attention.self.query.weight', 'encoder.layer.8.output.LayerNorm.bias', 'encoder.layer.3.output.dense.weight', 'encoder.layer.10.attention.self.value.weight', 'encoder.layer.6.output.dense.weight', 'encoder.layer.8.intermediate.dense.bias', 'encoder.layer.0.output.dense.bias', 'encoder.layer.4.attention.self.value.bias', 'encoder.layer.0.attention.self.key.bias', 'encoder.layer.4.attention.output.dense.bias', 'pooler.dense.bias', 'encoder.layer.10.attention.self.value.bias', 'encoder.layer.6.attention.self.key.weight', 'encoder.layer.10.attention.self.query.weight', 'encoder.layer.7.attention.output.LayerNorm.weight', 'encoder.layer.11.attention.self.value.bias', 'encoder.layer.10.attention.self.key.bias', 'encoder.layer.0.attention.self.key.weight', 'encoder.layer.9.attention.output.LayerNorm.bias', 'encoder.layer.11.attention.output.dense.weight', 'encoder.layer.7.attention.self.value.weight', 'encoder.layer.1.attention.self.value.weight', 'encoder.layer.3.intermediate.dense.bias', 'encoder.layer.9.attention.self.query.bias', 'embeddings.LayerNorm.weight', 'encoder.layer.5.attention.output.LayerNorm.bias', 'encoder.layer.1.output.dense.bias', 'encoder.layer.11.output.dense.bias', 'encoder.layer.2.output.dense.weight', 'encoder.layer.6.attention.self.value.weight', 'embeddings.LayerNorm.bias', 'encoder.layer.2.attention.self.value.bias', 'encoder.layer.0.intermediate.dense.bias', 'encoder.layer.11.attention.self.key.weight', 'encoder.layer.0.output.LayerNorm.weight', 'encoder.layer.9.intermediate.dense.bias', 'encoder.layer.3.output.LayerNorm.weight', 'encoder.layer.2.output.dense.bias', 'encoder.layer.11.attention.output.LayerNorm.bias', 'encoder.layer.9.output.dense.weight', 'encoder.layer.0.attention.self.value.bias', 'encoder.layer.0.attention.output.LayerNorm.weight', 'encoder.layer.4.output.dense.bias', 'encoder.layer.5.attention.self.value.weight', 'encoder.layer.9.output.dense.bias', 'encoder.layer.11.attention.output.dense.bias', 
'encoder.layer.8.output.LayerNorm.weight', 'encoder.layer.0.attention.self.value.weight', 'encoder.layer.10.output.dense.weight', 'encoder.layer.9.output.LayerNorm.bias', 'encoder.layer.8.attention.self.query.weight', 'encoder.layer.9.intermediate.dense.weight', 'encoder.layer.10.output.LayerNorm.weight', 'encoder.layer.8.attention.self.value.bias', 'encoder.layer.1.attention.self.query.weight', 'encoder.layer.2.attention.output.LayerNorm.bias', 'encoder.layer.3.output.dense.bias', 'encoder.layer.4.attention.output.dense.weight', 'encoder.layer.5.output.LayerNorm.bias', 'encoder.layer.2.attention.self.key.weight', 'encoder.layer.5.attention.output.dense.bias', 'encoder.layer.11.output.dense.weight', 'encoder.layer.3.attention.self.query.bias', 'encoder.layer.0.output.dense.weight', 'encoder.layer.6.attention.output.dense.bias', 'encoder.layer.7.output.dense.bias', 'encoder.layer.2.attention.output.LayerNorm.weight', 'encoder.layer.6.output.LayerNorm.bias', 'encoder.layer.10.output.dense.bias', 'pooler.dense.weight', 'encoder.layer.0.attention.self.query.weight', 'encoder.layer.3.output.LayerNorm.bias', 'encoder.layer.3.attention.self.value.weight', 'encoder.layer.5.attention.output.LayerNorm.weight', 'encoder.layer.6.attention.self.query.weight', 'encoder.layer.8.attention.self.query.bias', 'encoder.layer.2.attention.self.query.bias', 'encoder.layer.2.intermediate.dense.weight', 'encoder.layer.4.attention.output.LayerNorm.bias', 'encoder.layer.8.attention.output.LayerNorm.weight', 'encoder.layer.9.attention.output.dense.weight', 'encoder.layer.0.attention.output.dense.bias', 'encoder.layer.1.attention.self.key.weight', 'encoder.layer.3.attention.self.key.bias', 'encoder.layer.4.attention.self.query.weight', 'encoder.layer.7.attention.self.key.bias', 'encoder.layer.8.attention.self.key.weight', 'embeddings.word_embeddings.weight', 'encoder.layer.1.attention.output.dense.weight', 'encoder.layer.4.intermediate.dense.bias', 'encoder.layer.8.attention.self.key.bias', 
'encoder.layer.7.attention.self.query.bias', 'encoder.layer.1.attention.self.key.bias', 'encoder.layer.4.output.dense.weight', 'encoder.layer.4.attention.self.query.bias', 'encoder.layer.3.attention.output.dense.bias', 'encoder.layer.4.attention.self.value.weight', 'encoder.layer.4.attention.output.LayerNorm.weight', 'encoder.layer.9.attention.output.LayerNorm.weight', 'encoder.layer.0.attention.self.query.bias', 'encoder.layer.7.attention.self.key.weight', 'encoder.layer.5.attention.self.query.bias', 'encoder.layer.8.intermediate.dense.weight', 'encoder.layer.8.attention.self.value.weight', 'encoder.layer.8.attention.output.dense.weight', 'encoder.layer.7.attention.output.LayerNorm.bias', 'encoder.layer.1.intermediate.dense.bias', 'encoder.layer.1.attention.self.value.bias', 'encoder.layer.2.attention.self.value.weight', 'encoder.layer.8.output.dense.weight', 'encoder.layer.11.output.LayerNorm.weight', 'encoder.layer.9.attention.self.key.weight', 'encoder.layer.2.output.LayerNorm.bias', 'encoder.layer.6.intermediate.dense.bias', 'encoder.layer.6.output.LayerNorm.weight', 'encoder.layer.7.attention.self.query.weight', 'encoder.layer.5.attention.self.value.bias', 'encoder.layer.10.attention.output.dense.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "num_labels=7\n", "model = BertForMultilabelSequenceClassification.from_pretrained(model_ckpt, num_labels=num_labels).to('cuda')" ] }, { "cell_type": "code", "execution_count": 22, "id": "74af900d-0688-4f7b-b8f2-56f36f467a06", "metadata": { "tags": [] }, "outputs": [], "source": [ "def accuracy_thresh(y_pred, y_true, thresh=0.5, sigmoid=True):\n", " y_pred = torch.from_numpy(y_pred)\n", " y_true = torch.from_numpy(y_true)\n", " if sigmoid:\n", " y_pred = y_pred.sigmoid()\n", " return ((y_pred>thresh)==y_true.bool()).float().mean().item()" ] }, { "cell_type": "code", "execution_count": 23, "id": 
"db202a97-61e1-4e43-bb93-20179c2c0aa2", "metadata": { "tags": [] }, "outputs": [], "source": [ "def compute_metrics(eval_pred):\n", " predictions, labels = eval_pred\n", " return {'accuracy_thresh': accuracy_thresh(predictions, labels)}" ] }, { "cell_type": "code", "execution_count": 24, "id": "e0ab370a-fc4d-460b-9dab-dbde755dc3f4", "metadata": {}, "outputs": [], "source": [ "class MultilabelTrainer(Trainer):\n", "    \"\"\"Trainer whose loss treats every label as an independent binary target (multi-label).\"\"\"\n", "\n", "    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n", "        # **kwargs absorbs extra keyword arguments that newer Trainer releases pass\n", "        # (e.g. num_items_in_batch); a fixed signature raises TypeError there.\n", "        labels = inputs.pop(\"labels\")\n", "        outputs = model(**inputs)\n", "        logits = outputs.logits\n", "        # One sigmoid + binary cross-entropy per label; labels are stored as ints, so cast to float.\n", "        loss_fct = torch.nn.BCEWithLogitsLoss()\n", "        num_labels = self.model.config.num_labels\n", "        loss = loss_fct(logits.view(-1, num_labels),\n", "                        labels.float().view(-1, num_labels))\n", "        return (loss, outputs) if return_outputs else loss" ] }, { "cell_type": "code", "execution_count": 32, "id": "340ade6d-1eb1-47ec-b8e6-56371083e361", "metadata": {}, "outputs": [], "source": [ "batch_size = 8\n", "\n", "args = TrainingArguments(\n", " output_dir=\"aift-model-review-multiple-label-classification\",\n", " evaluation_strategy = \"epoch\",\n", " learning_rate=2e-5,\n", " per_device_train_batch_size=batch_size,\n", " per_device_eval_batch_size=batch_size,\n", " num_train_epochs=10,\n", " weight_decay=0.01,\n", " use_cpu = False\n", ")" ] }, { "cell_type": "code", "execution_count": 33, "id": "39d8e955-9ca8-463c-899a-bd3b1d5f2c0e", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "# Store label names and the multi-label problem type in the config so the saved model\n", "# is self-describing; pick the device instead of assuming a GPU is present.\n", "model = AutoModelForSequenceClassification.from_pretrained(\n", "    model_ckpt,\n", "    num_labels=num_labels,\n", "    problem_type=\"multi_label_classification\",\n", "    id2label=id2label,\n", "    label2id=label2id,\n", ").to(\"cuda\" if torch.cuda.is_available() else \"cpu\")" ] }, { 
"cell_type": "code", "execution_count": 34, "id": "3cb96e02-f0f7-4a0a-9fe6-f88fe89826f8", "metadata": {}, "outputs": [], "source": [ "trainer = MultilabelTrainer(\n", " model,\n", " args,\n", " train_dataset=dataset[\"train\"],\n", " eval_dataset=dataset[\"test\"],\n", " compute_metrics=compute_metrics,\n", " tokenizer=tokenizer)" ] }, { "cell_type": "code", "execution_count": 35, "id": "da79a882-f1f1-41a5-b4dd-98b070012c4c", "metadata": {}, "outputs": [ { "data": { "text/html": [ "\n", "
Epoch | \n", "Training Loss | \n", "Validation Loss | \n", "Accuracy Thresh | \n", "
---|---|---|---|
1 | \n", "No log | \n", "0.415191 | \n", "0.868367 | \n", "
2 | \n", "No log | \n", "0.302631 | \n", "0.901020 | \n", "
3 | \n", "No log | \n", "0.240627 | \n", "0.928571 | \n", "
4 | \n", "No log | \n", "0.217601 | \n", "0.931633 | \n", "
5 | \n", "No log | \n", "0.203845 | \n", "0.924490 | \n", "
6 | \n", "No log | \n", "0.192444 | \n", "0.929592 | \n", "
7 | \n", "No log | \n", "0.190031 | \n", "0.926531 | \n", "
8 | \n", "0.265200 | \n", "0.186760 | \n", "0.928571 | \n", "
9 | \n", "0.265200 | \n", "0.180436 | \n", "0.936735 | \n", "
10 | \n", "0.265200 | \n", "0.179821 | \n", "0.934694 | \n", "
"
],
"text/plain": [
"