# clip_gpt2/neuralnet/train.py
# Training script for the SeqToSeq image-captioning model (Flickr30k).
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter # For TensorBoard
from utils import save_checkpoint, load_checkpoint, print_examples
from dataset import get_loader
from model import SeqToSeq
from tabulate import tabulate # To tabulate loss and epoch
import argparse
import json
def main(args):
    """Train the SeqToSeq captioning model, logging to TensorBoard and
    checkpointing whenever the (last-batch) epoch loss improves.

    Args:
        args: Parsed CLI namespace with root_dir, csv_file, log_dir,
            save_path, batch_size, num_epochs, embed_size, hidden_size,
            lr and num_layers.
    """
    transform = transforms.Compose(
        [
            transforms.Resize((356, 356)),
            transforms.RandomCrop((299, 299)),
            transforms.ToTensor(),
            # Scale channels to roughly [-1, 1].
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    train_loader, _ = get_loader(
        root_folder=args.root_dir,
        annotation_file=args.csv_file,
        transform=transform,
        # Fix: was hard-coded to 64, silently ignoring the --batch_size flag.
        batch_size=args.batch_size,
        num_workers=2,
    )

    # Fix: json.load(open(...)) leaked the file handle; close it promptly.
    with open('vocab.json') as f:
        vocab = json.load(f)

    torch.backends.cudnn.benchmark = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    load_model = False
    save_model = True
    train_CNN = False  # set True to fine-tune the whole CNN encoder

    # Hyperparameters
    embed_size = args.embed_size
    hidden_size = args.hidden_size
    vocab_size = len(vocab['stoi'])
    num_layers = args.num_layers
    learning_rate = args.lr
    num_epochs = args.num_epochs

    # TensorBoard writer; `step` counts batches across the whole run.
    writer = SummaryWriter(args.log_dir)
    step = 0

    model_params = {
        'embed_size': embed_size,
        'hidden_size': hidden_size,
        'vocab_size': vocab_size,
        'num_layers': num_layers,
    }

    # Initialize model, loss and optimizer.
    model = SeqToSeq(**model_params, device=device).to(device)
    criterion = nn.CrossEntropyLoss(ignore_index=vocab['stoi']["<PAD>"])
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Freeze the inception backbone except its final fc layer, which always
    # trains; flip train_CNN above to unfreeze everything.
    for name, param in model.encoder.inception.named_parameters():
        if "fc.weight" in name or "fc.bias" in name:
            param.requires_grad = True
        else:
            param.requires_grad = train_CNN

    # Optionally resume from a saved checkpoint.
    if load_model:
        step = load_checkpoint(torch.load(args.save_path), model, optimizer)

    model.train()

    # Fix: was initialized to the magic value 10, which would skip
    # checkpointing any epoch whose loss exceeded 10.
    best_loss, best_epoch = float('inf'), 0

    for epoch in range(num_epochs):
        print_examples(model, device, vocab['itos'])

        for idx, (imgs, captions) in tqdm(
                enumerate(train_loader), total=len(train_loader), leave=False):
            imgs = imgs.to(device)
            captions = captions.to(device)

            # Feed every caption token but the last; targets are the full
            # caption, shifted by the model's next-token prediction.
            outputs = model(imgs, captions[:-1])
            loss = criterion(
                outputs.reshape(-1, outputs.shape[2]), captions.reshape(-1)
            )

            writer.add_scalar("Training loss", loss.item(), global_step=step)
            step += 1

            optimizer.zero_grad()
            # Fix: loss.backward(loss) passed the loss as the gradient
            # argument, scaling every gradient by the loss value. A scalar
            # loss needs no argument.
            loss.backward()
            optimizer.step()

        train_loss = loss.item()  # loss of the epoch's final batch
        if train_loss < best_loss:
            best_loss = train_loss
            best_epoch = epoch + 1
            if save_model:
                checkpoint = {
                    "model_params": model_params,
                    "state_dict": model.state_dict(),
                    "optimizer": optimizer.state_dict(),
                    "step": step,
                }
                save_checkpoint(checkpoint, args.save_path)

        table = [["Loss:", train_loss],
                 ["Step:", step],
                 ["Epoch:", epoch + 1],
                 ["Best Loss:", best_loss],
                 ["Best Epoch:", best_epoch]]
        print(tabulate(table))
if __name__ == "__main__":
    # CLI entry point: gather paths and hyperparameters, then train.
    parser = argparse.ArgumentParser()

    # Data and output locations.
    parser.add_argument('--root_dir', type=str,
                        default='./flickr30k/flickr30k_images',
                        help='path to images folder')
    parser.add_argument('--csv_file', type=str,
                        default='./flickr30k/results.csv',
                        help='path to captions csv file')
    parser.add_argument('--log_dir', type=str,
                        default='./drive/MyDrive/TensorBoard/',
                        help='path to save tensorboard logs')
    parser.add_argument('--save_path', type=str,
                        default='./drive/MyDrive/checkpoints/Seq2Seq.pt',
                        help='path to save checkpoint')

    # Model and training hyperparameters.
    parser.add_argument('--batch_size', type=int, default=64)
    parser.add_argument('--num_epochs', type=int, default=100)
    parser.add_argument('--embed_size', type=int, default=256)
    parser.add_argument('--hidden_size', type=int, default=512)
    parser.add_argument('--lr', type=float, default=0.001)
    parser.add_argument('--num_layers', type=int, default=3,
                        help='number of lstm layers')

    main(parser.parse_args())