# File size: 4,577 bytes
# Commit: f55b152
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter # For TensorBoard
from utils import save_checkpoint, load_checkpoint, print_examples
from dataset import get_loader
from model import SeqToSeq
from tabulate import tabulate # To tabulate loss and epoch
import argparse
import json
def _train_one_epoch(model, train_loader, criterion, optimizer, device, writer, step):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: the captioning model, called as ``model(imgs, captions[:-1])``.
        train_loader: iterable yielding ``(imgs, captions)`` batches.
        criterion: loss function applied to the flattened outputs/targets.
        optimizer: optimizer stepping ``model``'s parameters.
        device: device every batch is moved to.
        writer: TensorBoard ``SummaryWriter`` receiving the per-batch loss.
        step: global TensorBoard step at the start of the epoch.

    Returns:
        ``(mean_loss, step)`` -- the mean of the per-batch losses (``None``
        when the loader yields no batches) and the updated global step.
    """
    total_loss, num_batches = 0.0, 0
    for imgs, captions in tqdm(train_loader, total=len(train_loader), leave=False):
        imgs = imgs.to(device)
        captions = captions.to(device)

        # The model receives every caption position except the last, and its
        # outputs are scored against the full caption tensor.
        # NOTE(review): this assumes SeqToSeq emits one extra leading timestep
        # (e.g. the image feature) so outputs and captions line up -- confirm.
        outputs = model(imgs, captions[:-1])
        loss = criterion(outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1))

        batch_loss = loss.item()
        writer.add_scalar("Training loss", batch_loss, global_step=step)
        step += 1

        optimizer.zero_grad()
        # No gradient argument: for a scalar loss the implicit upstream gradient
        # is 1.0. Passing ``loss`` itself would scale every gradient by the loss.
        loss.backward()
        optimizer.step()

        total_loss += batch_loss
        num_batches += 1

    mean_loss = total_loss / num_batches if num_batches else None
    return mean_loss, step


def main(args):
    """Train the SeqToSeq image-captioning model.

    Args:
        args: parsed command-line namespace. Required attributes: ``root_dir``,
            ``csv_file``, ``log_dir``, ``save_path``, ``batch_size``,
            ``num_epochs``, ``embed_size``, ``hidden_size``, ``lr`` and
            ``num_layers``. Optional attributes ``load_model``, ``train_cnn``
            and ``num_workers`` default to ``False``, ``False`` and ``2``.

    Side effects:
        Reads ``vocab.json`` from the working directory, writes TensorBoard
        logs to ``args.log_dir`` and, whenever the mean training loss of an
        epoch improves on the best seen so far, saves a checkpoint to
        ``args.save_path``.

    Raises:
        ValueError: if the training loader yields no batches.
    """
    transform = transforms.Compose(
        [
            transforms.Resize((356, 356)),
            transforms.RandomCrop((299, 299)),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )
    train_loader, _ = get_loader(
        root_folder=args.root_dir,
        annotation_file=args.csv_file,
        transform=transform,
        batch_size=args.batch_size,
        num_workers=getattr(args, "num_workers", 2),
    )
    # JSON is UTF-8 by specification; don't depend on the platform locale.
    with open("vocab.json", encoding="utf-8") as vocab_file:
        vocab = json.load(vocab_file)

    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    load_model = getattr(args, "load_model", False)
    save_model = True
    train_CNN = getattr(args, "train_cnn", False)

    model_params = {
        "embed_size": args.embed_size,
        "hidden_size": args.hidden_size,
        "vocab_size": len(vocab["stoi"]),
        "num_layers": args.num_layers,
    }
    model = SeqToSeq(**model_params, device=device).to(device)
    # Padding positions do not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=vocab["stoi"]["<PAD>"])
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Freeze the Inception backbone except its final fc layer, unless
    # train_CNN asks for the whole backbone to be fine-tuned.
    for name, param in model.encoder.inception.named_parameters():
        param.requires_grad = train_CNN or "fc.weight" in name or "fc.bias" in name

    step = 0
    if load_model:
        # map_location lets a checkpoint saved on GPU be restored on CPU.
        checkpoint = torch.load(args.save_path, map_location=device)
        step = load_checkpoint(checkpoint, model, optimizer)

    writer = SummaryWriter(args.log_dir)
    model.train()
    best_loss, best_epoch = float("inf"), 0
    try:
        for epoch in range(1, args.num_epochs + 1):
            print_examples(model, device, vocab["itos"])
            train_loss, step = _train_one_epoch(
                model, train_loader, criterion, optimizer, device, writer, step
            )
            if train_loss is None:
                raise ValueError("train_loader yielded no batches; nothing to train on")

            if train_loss < best_loss:
                best_loss, best_epoch = train_loss, epoch
                if save_model:
                    checkpoint = {
                        "model_params": model_params,
                        "state_dict": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "step": step,
                    }
                    save_checkpoint(checkpoint, args.save_path)

            table = [
                ["Loss:", train_loss],
                ["Step:", step],
                ["Epoch:", epoch],
                ["Best Loss:", best_loss],
                ["Best Epoch:", best_epoch],
            ]
            print(tabulate(table))
    finally:
        # Flush any pending TensorBoard events even if training fails.
        writer.close()
if __name__ == "__main__":
    # Command-line entry point: parse options and hand them to main().
    parser = argparse.ArgumentParser(description="Train the SeqToSeq image-captioning model.")

    # Data and output locations
    parser.add_argument('--root_dir', type=str, default='./flickr30k/flickr30k_images', help='path to images folder')
    parser.add_argument('--csv_file', type=str, default='./flickr30k/results.csv', help='path to captions csv file')
    parser.add_argument('--log_dir', type=str, default='./drive/MyDrive/TensorBoard/', help='path to save tensorboard logs')
    parser.add_argument('--save_path', type=str, default='./drive/MyDrive/checkpoints/Seq2Seq.pt', help='path to save checkpoint')

    # Training hyperparameters and run options
    parser.add_argument('--batch_size', type=int, default=64, help='training batch size')
    parser.add_argument('--num_epochs', type=int, default=100, help='number of training epochs')
    parser.add_argument('--lr', type=float, default=0.001, help='Adam learning rate')
    parser.add_argument('--num_workers', type=int, default=2, help='number of data-loading worker processes')
    parser.add_argument('--load_model', action='store_true', help='resume from the checkpoint at --save_path')
    parser.add_argument('--train_cnn', action='store_true', help='fine-tune the whole CNN backbone, not only its fc layer')

    # Model architecture
    parser.add_argument('--embed_size', type=int, default=256, help='embedding dimension')
    parser.add_argument('--hidden_size', type=int, default=512, help='LSTM hidden-state size')
    parser.add_argument('--num_layers', type=int, default=3, help='number of lstm layers')

    args = parser.parse_args()
    main(args)