Image-to-Text
Chinese
English
FLIP / FLIP-demo / main.py
OpenFace-CQUPT
Upload 14 files
6e6d6a7 verified
raw
history blame
3.03 kB
import argparse
import numpy as np
import random
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from torch.cuda.amp import GradScaler, autocast
from models.FFLIP import FLIP
from models import utils
from eval.pretrain_eval import evaluation, itm_eval
from data import create_dataset, create_sampler, create_loader
def main(args):
    """Evaluate a pretrained FLIP model on the FaceCaption retrieval test set.

    Initializes (optionally distributed) execution, seeds all RNGs, builds the
    train/test dataloaders, loads the pretrained FLIP checkpoint, runs the
    image-to-text / text-to-image retrieval evaluation, and prints the metrics
    on the main process.

    Args:
        args: Parsed command-line namespace; must provide ``device``, ``seed``,
            ``pretrained``, ``distributed`` (and ``gpu`` when distributed —
            set by ``utils.init_distributed_mode``).
    """
    utils.init_distributed_mode(args)
    device = torch.device(args.device)

    # Offset the seed by rank so each worker draws a different but
    # deterministic random stream.
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### The reference code for creating the dataset ####
    print("Creating dataset")
    train_dataset, test_dataset = create_dataset(args, 'facecaption')
    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        # Only the train set is sharded across ranks; the test set keeps a
        # default (None) sampler so every rank sees the full test split.
        samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None]
    else:
        samplers = [None, None]
    train_loader, test_loader = create_loader([train_dataset, test_dataset], samplers,
                                              batch_size=[80, 80],
                                              num_workers=[8, 8],
                                              is_trains=[True, False],
                                              collate_fns=[None, None])

    #### Model ####
    print("Creating model")
    model = FLIP(pretrained=args.pretrained, vit='base', queue_size=61440)
    model = model.to(device)

    # Keep a handle to the raw module so evaluation bypasses the DDP wrapper.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    print("Start evaluation")
    score_test_i2t, score_test_t2i = evaluation(args, model_without_ddp, test_loader, device)

    # Only rank 0 computes and reports the retrieval metrics.
    if utils.is_main_process():
        test_result = itm_eval(score_test_i2t, score_test_t2i, test_loader.dataset.txt2img,
                               test_loader.dataset.img2txt)
        print(test_result)

    # Synchronize all ranks before returning so no worker exits early.
    if args.distributed:
        dist.barrier()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--output_dir', default='./outputs')
parser.add_argument('--img_root', default='./FaceCaption/images')
parser.add_argument('--ann_root', default='.FaceCaption/caption')
parser.add_argument('--pretrained', default='./FaceCaption-15M-base.pth')
parser.add_argument('--device', default='cuda')
parser.add_argument('--seed', default=42, type=int)
parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
parser.add_argument('--distributed', default=False, type=bool, help='whether to use distributed mode to training')
args = parser.parse_args()
main(args)