|
import torch |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
|
|
|
|
def print_examples(model, device, vocab):
    """Caption a fixed set of sample images and print each prediction.

    The model is switched to eval mode for inference and restored to
    train mode before returning, so this can be called mid-training.

    Args:
        model: captioning model exposing ``caption_image(image, vocab)``
            which returns a list of word tokens.
        device: torch device the image tensors are moved to.
        vocab: vocabulary object passed through to ``model.caption_image``.
    """
    transform = transforms.Compose(
        [
            transforms.Resize((299, 299)),
            transforms.ToTensor(),
            # Normalize each channel to roughly [-1, 1].
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    model.eval()

    # One loop replaces five copy-pasted stanzas; the printed label is
    # derived from the filename, fixing the original mismatch where
    # surfing.png was labeled "wave.png".
    example_files = [
        "dog.png",
        "dirt_bike.png",
        "surfing.png",
        "horse.png",
        "camera.png",
    ]

    # Inference only — no_grad avoids building autograd graphs.
    with torch.no_grad():
        for name in example_files:
            image = transform(
                Image.open(f"./test_examples/{name}").convert("RGB")
            ).unsqueeze(0)
            caption = " ".join(model.caption_image(image.to(device), vocab))
            print(f"{name} PREDICTION: {caption}")

    model.train()
|
|
|
|
|
def save_checkpoint(state, filename="/content/drive/MyDrive/checkpoints/Seq2Seq.pt"):
    """Persist a training checkpoint to disk.

    Args:
        state: dict of everything needed to resume training
            (typically model/optimizer state dicts plus the step counter).
        filename: destination path for the serialized checkpoint.
    """
    print("=> Saving checkpoint")
    torch.save(state, filename)
|
|
|
|
|
def load_checkpoint(checkpoint, model, optimizer):
    """Restore model and optimizer state from a checkpoint dict.

    Args:
        checkpoint: mapping with keys "state_dict", "optimizer", "step".
        model: module whose parameters are restored in place.
        optimizer: optimizer whose state is restored in place.

    Returns:
        The training step number stored in the checkpoint.
    """
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    return checkpoint["step"]
|
|