# Train with Pytorch

In [1]:
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForSequenceClassification

raw_dataset = load_dataset("glue", "sst2")
checkpoint = "bert-base-uncased"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# # For MRPC
# def tokenize_function(sample):
# return tokenizer(sample["sentence1"], sample["sentence2"], truncation = True)

# For SST2
def tokenize_function(sample):
 return tokenizer(sample["sentence"], truncation = True)


tokenized_dataset = raw_dataset.map(tokenize_function, batched = True)
data_collator = DataCollatorWithPadding(tokenizer = tokenizer)

 from .autonotebook import tqdm as notebook_tqdm
Map: 100%|██████████| 872/872 [00:00<00:00, 15492.15 examples/s]


# Preprocess the dataset 

In [2]:
# Remove unwanted columns which are not to be uitilized during pytorch dataloading
# # For MRPC
# tokenized_dataset = tokenized_dataset.remove_columns(["sentence1", "sentence2", "idx"])

# For SST2
tokenized_dataset = tokenized_dataset.remove_columns(["sentence", "idx"])

# Rename the target column appropriately
tokenized_dataset = tokenized_dataset.rename_column("label", "labels")

# Set the format to return tensors instead of lists
tokenized_dataset.set_format("torch")

tokenized_dataset.column_names

{'train': ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
 'validation': ['labels', 'input_ids', 'token_type_ids', 'attention_mask'],
 'test': ['labels', 'input_ids', 'token_type_ids', 'attention_mask']}

In [3]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(tokenized_dataset["train"], shuffle = True, batch_size = 64, collate_fn = data_collator)
eval_dataloader = DataLoader(tokenized_dataset["validation"], batch_size = 64, collate_fn= data_collator)

In [4]:
one_batch = next(iter(train_dataloader))
{k: v.shape for k, v in one_batch.items()}

You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'labels': torch.Size([64]),
 'input_ids': torch.Size([64, 41]),
 'token_type_ids': torch.Size([64, 41]),
 'attention_mask': torch.Size([64, 41])}

# Define the model and start training

In [5]:
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels = 2)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly i

In [6]:
import torch
model.eval()
with torch.no_grad():
 print(model(**one_batch))

SequenceClassifierOutput(loss=tensor(0.7528), logits=tensor([[-0.4735, 0.2345],
 [-0.5462, 0.2849],
 [-0.8623, 0.6073],
 [-0.6334, 0.3747],
 [-0.5882, 0.4656],
 [-0.1711, 0.1957],
 [-0.4656, 0.2387],
 [-0.8434, 0.6939],
 [-0.4384, 0.2810],
 [-0.5239, 0.2832],
 [-0.4431, 0.2877],
 [-0.5974, 0.2958],
 [-0.7655, 0.6273],
 [-0.7656, 0.6703],
 [-0.7001, 0.4183],
 [-0.3617, 0.2145],
 [-0.6250, 0.3684],
 [-0.5722, 0.4677],
 [-0.1536, 0.1978],
 [-0.5606, 0.3755],
 [-0.6292, 0.3662],
 [-0.7420, 0.3527],
 [-0.4581, 0.2733],
 [-0.6560, 0.4098],
 [-0.2436, 0.1589],
 [-0.5316, 0.2916],
 [-0.6136, 0.3340],
 [-0.6650, 0.3447],
 [-0.6319, 0.4982],
 [-0.7093, 0.4292],
 [-0.3495, 0.2136],
 [-0.5344, 0.2056],
 [-0.2243, 0.2376],
 [-0.2150, 0.2638],
 [-0.6236, 0.4449],
 [-0.3363, 0.2330],
 [-0.7103, 0.5592],
 [-0.6709, 0.4674],
 [-0.6250, 0.4823],
 [-0.8934, 0.8637],
 [-0.7147, 0.4695],
 [-0.4029, 0.2238],
 [-0.6455, 0.4327],
 [-0.2547, 0.2432],
 [-0.3518, 0.3581],
 [-0.1312, 0.1507],
 [-0.5558, 0.4219],


In [7]:
from transformers import AdamW
from transformers import get_scheduler

# Define the optimizer here
optimizer = AdamW(model.parameters(), lr = 5e-5)

# Define the learning rate scheduler here
num_epochs = 2
num_training_steps = num_epochs * len(train_dataloader)
lr_scheduler = get_scheduler(
 "linear",
 optimizer=optimizer,
 num_warmup_steps=0,
 num_training_steps=num_training_steps,
)
print(num_training_steps)




2106


In [8]:
# Use GPU if available
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
model.to(device);

In [9]:
from tqdm.auto import tqdm
import evaluate
progress_bar = tqdm(range(num_training_steps))

for epoch_id in range(num_epochs):

 # Train for one epoch
 model.train()
 for batch in train_dataloader:
 batch = {k: v.to(device) for k, v in batch.items()}
 outputs = model(**batch)
 outputs.loss.backward()

 optimizer.step()
 lr_scheduler.step()
 optimizer.zero_grad()
 progress_bar.update(1)

 # Evaluate at the end of epoch
 model.eval()
 # # For MRPC
 # metric = evaluate.load("glue", "mrpc")

 # For SST2
 metric = evaluate.load("glue", "sst2")

 with torch.no_grad():
 for batch in eval_dataloader:
 batch = {k: v.to(device) for k, v in batch.items()}
 outputs = model(**batch)
 logits = outputs.logits
 predictions = logits.argmax(dim = -1)
 metric.add_batch(predictions = predictions, references = batch["labels"])
 m = metric.compute()

 print(f"Metrics at end of epoch {epoch_id}:\n{m}")


 50%|█████ | 1054/2106 [03:48<13:25, 1.31it/s]

Metrics at end of epoch 0:
{'accuracy': 0.9288990825688074}


100%|█████████▉| 2105/2106 [07:35<00:00, 4.98it/s]

Metrics at end of epoch 1:
{'accuracy': 0.926605504587156}
