{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Train with Pytorch" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/huggingface/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n", "/home/huggingface/lib/python3.10/site-packages/huggingface_hub/file_download.py:1150: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n", " warnings.warn(\n", "Map: 100%|██████████| 872/872 [00:00<00:00, 15492.15 examples/s]\n" ] } ], "source": [ "from datasets import load_dataset\n", "from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification\n", "\n", "raw_dataset = load_dataset(\"glue\", \"sst2\")\n", "checkpoint = \"bert-base-uncased\"\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(checkpoint)\n", "\n", "# # For MRPC\n", "# def tokenize_function(sample):\n", "# return tokenizer(sample[\"sentence1\"], sample[\"sentence2\"], truncation = True)\n", "\n", "# For SST2\n", "def tokenize_function(sample):\n", " return tokenizer(sample[\"sentence\"], truncation = True)\n", "\n", "\n", "tokenized_dataset = raw_dataset.map(tokenize_function, batched = True)\n", "data_collator = DataCollatorWithPadding(tokenizer = tokenizer)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Preprocess the dataset " ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'train': ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],\n", " 'validation': ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],\n", " 'test': ['labels', 'input_ids', 'token_type_ids', 'attention_mask']}" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Remove unwanted columns which are not to be uitilized during pytorch dataloading\n", "# # For MRPC\n", "# tokenized_dataset = tokenized_dataset.remove_columns([\"sentence1\", \"sentence2\", \"idx\"])\n", "\n", "# For SST2\n", "tokenized_dataset = tokenized_dataset.remove_columns([\"sentence\", \"idx\"])\n", "\n", "# Rename the target column appropriately\n", "tokenized_dataset = tokenized_dataset.rename_column(\"label\", \"labels\")\n", "\n", "# Set the format to return tensors instead of lists\n", "tokenized_dataset.set_format(\"torch\")\n", "\n", "tokenized_dataset.column_names" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", "\n", "train_dataloader = DataLoader(tokenized_dataset[\"train\"], shuffle = True, batch_size = 64, collate_fn = data_collator)\n", "eval_dataloader = DataLoader(tokenized_dataset[\"validation\"], batch_size = 64, collate_fn= data_collator)" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n" ] }, { "data": { "text/plain": [ "{'labels': torch.Size([64]),\n", " 'input_ids': torch.Size([64, 41]),\n", " 'token_type_ids': torch.Size([64, 41]),\n", " 'attention_mask': torch.Size([64, 41])}" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "one_batch = next(iter(train_dataloader))\n", "{k: v.shape for k, v in one_batch.items()}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Define the model and start training" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']\n", "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n", "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] } ], "source": [ "model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels = 2)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SequenceClassifierOutput(loss=tensor(0.7528), logits=tensor([[-0.4735, 0.2345],\n", " [-0.5462, 0.2849],\n", " [-0.8623, 0.6073],\n", " [-0.6334, 0.3747],\n", " [-0.5882, 0.4656],\n", " [-0.1711, 0.1957],\n", " [-0.4656, 0.2387],\n", " [-0.8434, 0.6939],\n", " [-0.4384, 0.2810],\n", " [-0.5239, 0.2832],\n", " [-0.4431, 0.2877],\n", " [-0.5974, 0.2958],\n", " [-0.7655, 0.6273],\n", " [-0.7656, 0.6703],\n", " [-0.7001, 0.4183],\n", " [-0.3617, 0.2145],\n", " [-0.6250, 0.3684],\n", " [-0.5722, 0.4677],\n", " [-0.1536, 0.1978],\n", " [-0.5606, 0.3755],\n", " [-0.6292, 0.3662],\n", " [-0.7420, 0.3527],\n", " [-0.4581, 0.2733],\n", " [-0.6560, 0.4098],\n", " [-0.2436, 0.1589],\n", " [-0.5316, 0.2916],\n", " [-0.6136, 0.3340],\n", " [-0.6650, 0.3447],\n", " [-0.6319, 0.4982],\n", " [-0.7093, 0.4292],\n", " [-0.3495, 0.2136],\n", " [-0.5344, 0.2056],\n", " [-0.2243, 0.2376],\n", " [-0.2150, 0.2638],\n", " [-0.6236, 0.4449],\n", " [-0.3363, 0.2330],\n", " [-0.7103, 0.5592],\n", " [-0.6709, 0.4674],\n", " [-0.6250, 0.4823],\n", " [-0.8934, 0.8637],\n", " [-0.7147, 0.4695],\n", " [-0.4029, 0.2238],\n", " [-0.6455, 0.4327],\n", " [-0.2547, 0.2432],\n", " [-0.3518, 0.3581],\n", " [-0.1312, 0.1507],\n", " [-0.5558, 0.4219],\n", " [-0.4881, 0.3416],\n", " [-0.6623, 0.4497],\n", " [-0.5963, 0.4848],\n", " [-0.5053, 0.3500],\n", " [-0.1152, 0.1482],\n", " [-0.6302, 0.3531],\n", " [-0.6268, 0.4978],\n", " [-0.4811, 0.2927],\n", " [ 0.0057, 0.1694],\n", " [-0.6268, 0.3306],\n", " [-0.5859, 0.4029],\n", " [-0.3552, 0.2425],\n", " [-0.5622, 0.4161],\n", " [-0.7670, 0.5203],\n", " [-0.6624, 0.5146],\n", " [-0.6089, 0.4091],\n", " [-0.4992, 0.2702]]), hidden_states=None, attentions=None)\n" ] } ], "source": [ "import torch\n", "model.eval()\n", "with torch.no_grad():\n", " print(model(**one_batch))" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/huggingface/lib/python3.10/site-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " warnings.warn(\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "2106\n" ] } ], "source": [ "from transformers import AdamW\n", "from transformers import get_scheduler\n", "\n", "# Define the optimizer here\n", "optimizer = AdamW(model.parameters(), lr = 5e-5)\n", "\n", "# Define the learning rate scheduler here\n", "num_epochs = 2\n", "num_training_steps = num_epochs * len(train_dataloader)\n", "lr_scheduler = get_scheduler(\n", " \"linear\",\n", " optimizer=optimizer,\n", " num_warmup_steps=0,\n", " num_training_steps=num_training_steps,\n", ")\n", "print(num_training_steps)\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "# Use GPU if available\n", "device = torch.device(\"cuda:0\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", "model.to(device);" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ " 50%|█████ | 1054/2106 [03:48<13:25, 1.31it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Metrics at end of epoch 0:\n", "{'accuracy': 0.9288990825688074}\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "100%|█████████▉| 2105/2106 [07:35<00:00, 4.98it/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Metrics at end of epoch 1:\n", "{'accuracy': 0.926605504587156}\n" ] } ], "source": [ "from tqdm.auto import tqdm\n", "import evaluate\n", "progress_bar = tqdm(range(num_training_steps))\n", "\n", "for epoch_id in range(num_epochs):\n", "\n", " # Train for one epoch\n", " model.train()\n", " for batch in train_dataloader:\n", " batch = {k: v.to(device) for k, v in batch.items()}\n", " outputs = model(**batch)\n", " outputs.loss.backward()\n", "\n", " optimizer.step()\n", " lr_scheduler.step()\n", " optimizer.zero_grad()\n", " progress_bar.update(1)\n", "\n", " # Evaluate at the end of epoch\n", " model.eval()\n", " # # For MRPC\n", " # metric = evaluate.load(\"glue\", \"mrpc\")\n", "\n", " # For SST2\n", " metric = evaluate.load(\"glue\", \"sst2\")\n", "\n", " with torch.no_grad():\n", " for batch in eval_dataloader:\n", " batch = {k: v.to(device) for k, v in batch.items()}\n", " outputs = model(**batch)\n", " logits = outputs.logits\n", " predictions = logits.argmax(dim = -1)\n", " metric.add_batch(predictions = predictions, references = batch[\"labels\"])\n", " m = metric.compute()\n", "\n", " print(f\"Metrics at end of epoch {epoch_id}:\\n{m}\")\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 2 }