Spaces:
Runtime error
Runtime error
the files
Browse files- LICENSE.txt +12 -0
- alarm.jpeg +0 -0
- alarm1.jpeg +0 -0
- app.py +36 -0
- eval_nocaps.py +118 -0
- requirements.txt +6 -0
- train_caption.py +206 -0
- train_nlvr.py +213 -0
- train_retrieval.py +345 -0
- train_vqa.py +202 -0
- utils.py +278 -0
- walk.jpeg +0 -0
LICENSE.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Copyright (c) 2022, Salesforce.com, Inc.
|
2 |
+
All rights reserved.
|
3 |
+
|
4 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
5 |
+
|
6 |
+
* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
7 |
+
|
8 |
+
* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
9 |
+
|
10 |
+
* Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
11 |
+
|
12 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
alarm.jpeg
ADDED
alarm1.jpeg
ADDED
app.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from models.blip_vqa import blip_vqa
|
3 |
+
import torch
|
4 |
+
from torchvision import transforms
|
5 |
+
from torchvision.transforms.functional import InterpolationMode
|
6 |
+
|
7 |
+
# Use the GPU when torch reports one, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Side length (pixels) of the square image handed to the model.
image_size = 480

# Preprocessing: bicubic resize to image_size x image_size, tensor conversion,
# then per-channel normalisation with the constants below.
transform = transforms.Compose([
    transforms.Resize(
        (image_size, image_size),
        interpolation=InterpolationMode.BICUBIC,
    ),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=(0.48145466, 0.4578275, 0.40821073),
        std=(0.26862954, 0.26130258, 0.27577711),
    ),
])

# Location passed to blip_vqa as `pretrained`.
model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth'

# Build the VQA model once at import time, put it in eval mode and move it
# to the selected device so every request reuses the same instance.
model = blip_vqa(pretrained=model_url, image_size=image_size, vit='base')
model.eval()
model = model.to(device)
|
20 |
+
|
21 |
+
|
22 |
+
def pool_alarm(raw_image):
    """Ask the VQA model whether someone is in the pool shown in ``raw_image``.

    Args:
        raw_image: The image to inspect (the UI is configured to pass a PIL
            image), or ``None`` when nothing was uploaded.

    Returns:
        A string of the form ``'answer: <model answer>'``. When ``raw_image``
        is ``None`` a fixed explanatory message is returned instead of
        letting the preprocessing pipeline fail on a missing image.
    """
    # Submitting the form without an image hands us None; answer explicitly
    # rather than crashing inside `transform`.
    if raw_image is None:
        return 'answer: no image provided'

    question = 'there is someone in the pool?'
    # Preprocess, add a leading batch dimension and move to the model's device.
    image = transform(raw_image).unsqueeze(0).to(device)
    # Inference only: no gradients are needed.
    with torch.no_grad():
        answer = model(image, question, train=False, inference='generate')

    # Only the first element of the model output is reported.
    return 'answer: ' + answer[0]
|
29 |
+
|
30 |
+
|
31 |
+
# Newer Gradio releases expose components at the top level (gr.Image,
# gr.Textbox); the gr.inputs / gr.outputs namespaces are deprecated and absent
# from recent versions. Prefer the top-level classes and fall back to the old
# namespaces only when they are all that is available.
# The variable is named `image_input` so the builtin `input` is not shadowed.
if hasattr(gr, 'Image') and hasattr(gr, 'Textbox'):
    image_input = gr.Image(type='pil')
    output = gr.Textbox()
else:
    image_input = gr.inputs.Image(type='pil')
    output = gr.outputs.Textbox()

# Sample images offered in the UI (paths relative to the app directory).
examples = ['alarm.jpeg', 'alarm1.jpeg', 'walk.jpeg']
title = ""
description = ""

# Keep a handle on the Interface itself (not on the return value of launch()).
intf = gr.Interface(
    fn=pool_alarm,
    inputs=image_input,
    outputs=output,
    examples=examples,
    title=title,
    description=description,
)
intf.launch()
|
eval_nocaps.py
ADDED
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
'''
|
8 |
+
import argparse
|
9 |
+
import os
|
10 |
+
import ruamel_yaml as yaml
|
11 |
+
import numpy as np
|
12 |
+
import random
|
13 |
+
import time
|
14 |
+
import datetime
|
15 |
+
import json
|
16 |
+
from pathlib import Path
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
import torch.backends.cudnn as cudnn
|
22 |
+
import torch.distributed as dist
|
23 |
+
from torch.utils.data import DataLoader
|
24 |
+
|
25 |
+
from models.blip import blip_decoder
|
26 |
+
import utils
|
27 |
+
from data import create_dataset, create_sampler, create_loader
|
28 |
+
from data.utils import save_result
|
29 |
+
|
30 |
+
@torch.no_grad()
def evaluate(model, data_loader, device, config):
    """Generate one caption per image yielded by ``data_loader``.

    The model is switched to eval mode and captions are produced with
    ``model.generate`` using ``sample=False``, the ``num_beams``,
    ``max_length`` and ``min_length`` entries of ``config``, and a fixed
    repetition penalty of 1.1.

    Returns:
        A list of ``{"image_id": ..., "caption": ...}`` dicts, in loader order.
    """
    model.eval()

    logger = utils.MetricLogger(delimiter=" ")
    log_header = 'Evaluation:'
    log_interval = 10

    records = []
    for image, image_id in logger.log_every(data_loader, log_interval, log_header):
        batch = image.to(device)

        captions = model.generate(
            batch,
            sample=False,
            num_beams=config['num_beams'],
            max_length=config['max_length'],
            min_length=config['min_length'],
            repetition_penalty=1.1,
        )

        # Pair each generated caption with the id of its image.
        records.extend(
            {"image_id": img_id.item(), "caption": caption}
            for caption, img_id in zip(captions, image_id)
        )

    return records
|
51 |
+
|
52 |
+
|
53 |
+
def main(args, config):
    """Caption the 'nocaps' validation and test splits and save the results.

    Initialises distributed mode, seeds the RNGs, builds the datasets,
    loaders and the BLIP decoder, then runs ``evaluate`` on each split and
    hands the output to ``save_result`` under ``args.result_dir``.
    """
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by the process rank)
    seed = args.seed + utils.get_rank()
    for seed_rng in (torch.manual_seed, np.random.seed, random.seed):
        seed_rng(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating captioning dataset")
    val_dataset, test_dataset = create_dataset('nocaps', config)

    samplers = [None, None]
    if args.distributed:
        world_size = utils.get_world_size()
        rank = utils.get_rank()
        samplers = create_sampler([val_dataset, test_dataset], [False, False], world_size, rank)

    val_loader, test_loader = create_loader(
        [val_dataset, test_dataset],
        samplers,
        batch_size=[config['batch_size']] * 2,
        num_workers=[4, 4],
        is_trains=[False, False],
        collate_fns=[None, None],
    )

    #### Model ####
    print("Creating model")
    model = blip_decoder(
        pretrained=config['pretrained'],
        image_size=config['image_size'],
        vit=config['vit'],
        prompt=config['prompt'],
    )
    model = model.to(device)

    # Generation is always run on the bare module, wrapped or not.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    # Validation first, then test; each split is saved right after it is evaluated.
    for split_name, loader in (('val', val_loader), ('test', test_loader)):
        split_result = evaluate(model_without_ddp, loader, device, config)
        save_result(split_result, args.result_dir, split_name, remove_duplicate='image_id')
|
96 |
+
|
97 |
+
|
98 |
+
def str2bool(value):
    """Convert a command-line token to a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including ``'False'``) becomes ``True``. This parser
    understands the usual spellings instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string.

    Returns:
        The parsed boolean. The empty string maps to ``False``, which is what
        ``bool('')`` produced before.

    Raises:
        argparse.ArgumentTypeError: If the string is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/nocaps.yaml')
    parser.add_argument('--output_dir', default='output/NoCaps')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    # str2bool (not bool) so that "--distributed False" really disables it.
    parser.add_argument('--distributed', default=True, type=str2bool)
    args = parser.parse_args()

    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # point --config at trusted files.
    with open(args.config, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)

    args.result_dir = os.path.join(args.output_dir, 'result')

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.result_dir).mkdir(parents=True, exist_ok=True)

    # Keep a copy of the effective configuration next to the outputs.
    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as config_copy:
        yaml.dump(config, config_copy)

    main(args, config)
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
timm==0.4.12
|
2 |
+
transformers==4.15.0
|
3 |
+
fairscale==0.4.4
|
4 |
+
pycocoevalcap
|
5 |
+
torch
|
6 |
+
torchvision
|
train_caption.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
'''
|
8 |
+
import argparse
|
9 |
+
import os
|
10 |
+
import ruamel_yaml as yaml
|
11 |
+
import numpy as np
|
12 |
+
import random
|
13 |
+
import time
|
14 |
+
import datetime
|
15 |
+
import json
|
16 |
+
from pathlib import Path
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
import torch.backends.cudnn as cudnn
|
22 |
+
import torch.distributed as dist
|
23 |
+
from torch.utils.data import DataLoader
|
24 |
+
|
25 |
+
from models.blip import blip_decoder
|
26 |
+
import utils
|
27 |
+
from utils import cosine_lr_schedule
|
28 |
+
from data import create_dataset, create_sampler, create_loader
|
29 |
+
from data.utils import save_result, coco_caption_eval
|
30 |
+
|
31 |
+
def train(model, data_loader, optimizer, epoch, device):
    """Run one captioning training epoch.

    For every batch the image is moved to ``device``, the model is called
    with the image and caption to obtain the loss, and a single optimizer
    step is taken. The loss and the first param group's learning rate are
    logged per step.

    Returns:
        A dict mapping each meter name to its global average formatted with
        three decimals.
    """
    model.train()

    logger = utils.MetricLogger(delimiter=" ")
    logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
    log_header = f'Train Caption Epoch: [{epoch}]'
    log_interval = 50

    # The third element of each batch is not used during training.
    for image, caption, _ in logger.log_every(data_loader, log_interval, log_header):
        image = image.to(device)

        loss = model(image, caption)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logger.update(loss=loss.item())
        logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    logger.synchronize_between_processes()
    print("Averaged stats:", logger.global_avg())

    summary = {}
    for name, meter in logger.meters.items():
        summary[name] = format(meter.global_avg, '.3f')
    return summary
|
57 |
+
|
58 |
+
|
59 |
+
@torch.no_grad()
def evaluate(model, data_loader, device, config):
    """Generate a caption for every image in ``data_loader``.

    Puts the model in eval mode and calls ``model.generate`` with
    ``sample=False`` and the ``num_beams``, ``max_length`` and ``min_length``
    values from ``config``.

    Returns:
        A list of ``{"image_id": ..., "caption": ...}`` dicts, in loader order.
    """
    model.eval()

    logger = utils.MetricLogger(delimiter=" ")
    log_header = 'Caption generation:'
    log_interval = 10

    collected = []
    for image, image_id in logger.log_every(data_loader, log_interval, log_header):
        batch = image.to(device)

        captions = model.generate(
            batch,
            sample=False,
            num_beams=config['num_beams'],
            max_length=config['max_length'],
            min_length=config['min_length'],
        )

        # One record per (caption, image id) pair of the batch.
        for caption, img_id in zip(captions, image_id):
            collected.append({"image_id": img_id.item(), "caption": caption})

    return collected
|
80 |
+
|
81 |
+
|
82 |
+
def main(args, config):
    """Fine-tune (or, with ``args.evaluate``, only evaluate) the BLIP captioner.

    Per epoch: optionally train, generate captions for the val and test
    splits, save them under ``args.result_dir`` and, on the main process,
    score them with ``coco_caption_eval``. In training mode the checkpoint
    with the highest validation ``CIDEr + Bleu_4`` is written to
    ``checkpoint_best.pth`` and a JSON line is appended to ``log.txt``; in
    evaluate mode a JSON line is appended to ``evaluate.txt`` and the loop
    stops after the first pass.
    """
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by the process rank)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating captioning dataset")
    train_dataset, val_dataset, test_dataset = create_dataset('caption_coco', config)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        samplers = create_sampler([train_dataset, val_dataset, test_dataset], [True, False, False],
                                  num_tasks, global_rank)
    else:
        samplers = [None, None, None]

    train_loader, val_loader, test_loader = create_loader(
        [train_dataset, val_dataset, test_dataset], samplers,
        batch_size=[config['batch_size']] * 3, num_workers=[4, 4, 4],
        is_trains=[True, False, False], collate_fns=[None, None, None])

    #### Model ####
    print("Creating model")
    model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
                         vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
                         prompt=config['prompt'])

    model = model.to(device)

    # Training goes through the (possibly DDP-wrapped) model; generation and
    # checkpointing use the bare module.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'],
                                  weight_decay=config['weight_decay'])

    best = 0
    best_epoch = 0

    print("Start training")
    start_time = time.time()
    for epoch in range(0, config['max_epoch']):
        if not args.evaluate:
            if args.distributed:
                train_loader.sampler.set_epoch(epoch)

            cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])

            train_stats = train(model, train_loader, optimizer, epoch, device)

        val_result = evaluate(model_without_ddp, val_loader, device, config)
        val_result_file = save_result(val_result, args.result_dir, 'val_epoch%d' % epoch,
                                      remove_duplicate='image_id')

        test_result = evaluate(model_without_ddp, test_loader, device, config)
        test_result_file = save_result(test_result, args.result_dir, 'test_epoch%d' % epoch,
                                       remove_duplicate='image_id')

        # Scoring, checkpointing and logging happen on the main process only.
        if utils.is_main_process():
            coco_val = coco_caption_eval(config['coco_gt_root'], val_result_file, 'val')
            coco_test = coco_caption_eval(config['coco_gt_root'], test_result_file, 'test')

            if args.evaluate:
                log_stats = {**{f'val_{k}': v for k, v in coco_val.eval.items()},
                             **{f'test_{k}': v for k, v in coco_test.eval.items()},
                             }
                with open(os.path.join(args.output_dir, "evaluate.txt"), "a") as f:
                    f.write(json.dumps(log_stats) + "\n")
            else:
                val_score = coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4']
                if val_score > best:
                    best = val_score
                    best_epoch = epoch
                    # Snapshot the state only when it is actually written.
                    save_obj = {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'config': config,
                        'epoch': epoch,
                    }
                    torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))

                log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                             **{f'val_{k}': v for k, v in coco_val.eval.items()},
                             **{f'test_{k}': v for k, v in coco_test.eval.items()},
                             'epoch': epoch,
                             'best_epoch': best_epoch,
                             }
                with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                    f.write(json.dumps(log_stats) + "\n")

        if args.evaluate:
            break
        # dist.barrier() requires an initialised process group; skip it when
        # the run is not distributed instead of raising.
        if args.distributed:
            dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
|
183 |
+
|
184 |
+
|
185 |
+
def str2bool(value):
    """Convert a command-line token to a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including ``'False'``) becomes ``True``. This parser
    understands the usual spellings instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string.

    Returns:
        The parsed boolean. The empty string maps to ``False``, which is what
        ``bool('')`` produced before.

    Raises:
        argparse.ArgumentTypeError: If the string is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/caption_coco.yaml')
    parser.add_argument('--output_dir', default='output/Caption_coco')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    # str2bool (not bool) so that "--distributed False" really disables it.
    parser.add_argument('--distributed', default=True, type=str2bool)
    args = parser.parse_args()

    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # point --config at trusted files.
    with open(args.config, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)

    args.result_dir = os.path.join(args.output_dir, 'result')

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.result_dir).mkdir(parents=True, exist_ok=True)

    # Keep a copy of the effective configuration next to the outputs.
    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as config_copy:
        yaml.dump(config, config_copy)

    main(args, config)
|
train_nlvr.py
ADDED
@@ -0,0 +1,213 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
'''
|
8 |
+
import argparse
|
9 |
+
import os
|
10 |
+
import ruamel_yaml as yaml
|
11 |
+
import numpy as np
|
12 |
+
import random
|
13 |
+
import time
|
14 |
+
import datetime
|
15 |
+
import json
|
16 |
+
from pathlib import Path
|
17 |
+
import json
|
18 |
+
import pickle
|
19 |
+
|
20 |
+
import torch
|
21 |
+
import torch.nn as nn
|
22 |
+
import torch.nn.functional as F
|
23 |
+
from torch.utils.data import DataLoader
|
24 |
+
import torch.backends.cudnn as cudnn
|
25 |
+
import torch.distributed as dist
|
26 |
+
|
27 |
+
from models.blip_nlvr import blip_nlvr
|
28 |
+
|
29 |
+
import utils
|
30 |
+
from utils import cosine_lr_schedule, warmup_lr_schedule
|
31 |
+
from data import create_dataset, create_sampler, create_loader
|
32 |
+
|
33 |
+
def train(model, data_loader, optimizer, epoch, device, config):
    """Run one NLVR training epoch.

    Each batch provides two images, a text and targets. The two image
    tensors are concatenated along dim 0 before being passed to the model,
    which returns the loss when called with ``train=True``. ``config`` is
    accepted for signature parity but is not read here.

    Returns:
        A dict mapping each meter name to its global average formatted with
        four decimals.
    """
    model.train()

    logger = utils.MetricLogger(delimiter=" ")
    logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}'))
    logger.add_meter('loss', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))

    log_header = f'Train Epoch: [{epoch}]'
    log_interval = 50

    for image0, image1, text, targets in logger.log_every(data_loader, log_interval, log_header):
        # Stack both images of every pair into one batch.
        images = torch.cat([image0, image1], dim=0)
        images = images.to(device)
        targets = targets.to(device)

        loss = model(images, text, targets=targets, train=True)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logger.update(lr=optimizer.param_groups[0]["lr"])
        logger.update(loss=loss.item())

    # gather the stats from all processes
    logger.synchronize_between_processes()
    print("Averaged stats:", logger.global_avg())

    summary = {}
    for name, meter in logger.meters.items():
        summary[name] = format(meter.global_avg, '.4f')
    return summary
|
63 |
+
|
64 |
+
|
65 |
+
@torch.no_grad()
def evaluate(model, data_loader, device, config):
    """Measure NLVR classification accuracy over ``data_loader``.

    For each batch the two image tensors are concatenated along dim 0, the
    model is called with ``train=False`` and the index of the largest value
    along dim 1 of its output is taken as the predicted class. Batch accuracy
    is accumulated in the ``'acc'`` meter, weighted by ``image0.size(0)``.
    ``config`` is not read here.

    Returns:
        A dict mapping each meter name to its global average formatted with
        four decimals.
    """
    model.eval()

    logger = utils.MetricLogger(delimiter=" ")

    log_header = 'Evaluation:'
    log_interval = 50

    for image0, image1, text, targets in logger.log_every(data_loader, log_interval, log_header):
        images = torch.cat([image0, image1], dim=0)
        images = images.to(device)
        targets = targets.to(device)

        prediction = model(images, text, targets=targets, train=False)

        # max(1) returns (values, indices); the indices are the class ids.
        pred_class = prediction.max(1)[1]
        num_correct = (targets == pred_class).sum()
        accuracy = num_correct / targets.size(0)

        logger.meters['acc'].update(accuracy.item(), n=image0.size(0))

    # gather the stats from all processes
    logger.synchronize_between_processes()

    print("Averaged stats:", logger.global_avg())

    summary = {}
    for name, meter in logger.meters.items():
        summary[name] = format(meter.global_avg, '.4f')
    return summary
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
def main(args, config):
    """Fine-tune (or, with ``args.evaluate``, only evaluate) BLIP on NLVR.

    Per epoch: optionally train, then evaluate on the val and test loaders.
    On the main process a JSON line with the stats is appended to
    ``log.txt``; in training mode the checkpoint with the highest validation
    accuracy is written to ``checkpoint_best.pth``. After the loop the best
    epoch is appended to ``log.txt``.
    """
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility (offset by the process rank)
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating dataset")
    datasets = create_dataset('nlvr', config)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        samplers = create_sampler(datasets, [True, False, False], num_tasks, global_rank)
    else:
        samplers = [None, None, None]

    batch_size = [config['batch_size_train'], config['batch_size_test'], config['batch_size_test']]
    train_loader, val_loader, test_loader = create_loader(datasets, samplers, batch_size=batch_size,
                                                          num_workers=[4, 4, 4],
                                                          is_trains=[True, False, False],
                                                          collate_fns=[None, None, None])

    #### Model ####
    print("Creating model")
    model = blip_nlvr(pretrained=config['pretrained'], image_size=config['image_size'],
                      vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'],
                      vit_ckpt_layer=config['vit_ckpt_layer'])

    model = model.to(device)

    # The bare module is what gets checkpointed.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'],
                                  weight_decay=config['weight_decay'])

    print("Start training")
    start_time = time.time()
    best = 0
    best_epoch = 0

    for epoch in range(0, config['max_epoch']):
        if not args.evaluate:
            if args.distributed:
                train_loader.sampler.set_epoch(epoch)

            cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])

            train_stats = train(model, train_loader, optimizer, epoch, device, config)

        val_stats = evaluate(model, val_loader, device, config)
        test_stats = evaluate(model, test_loader, device, config)

        # Logging and checkpointing happen on the main process only.
        if utils.is_main_process():
            if args.evaluate:
                log_stats = {**{f'val_{k}': v for k, v in val_stats.items()},
                             **{f'test_{k}': v for k, v in test_stats.items()},
                             }
                with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                    f.write(json.dumps(log_stats) + "\n")

            else:
                log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                             **{f'val_{k}': v for k, v in val_stats.items()},
                             **{f'test_{k}': v for k, v in test_stats.items()},
                             'epoch': epoch,
                             }

                # Stats arrive as formatted strings, hence the float().
                if float(val_stats['acc']) > best:
                    save_obj = {
                        'model': model_without_ddp.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'config': config,
                        'epoch': epoch,
                    }
                    torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
                    best = float(val_stats['acc'])
                    best_epoch = epoch

                with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                    f.write(json.dumps(log_stats) + "\n")
        if args.evaluate:
            break

        # dist.barrier() requires an initialised process group; skip it when
        # the run is not distributed instead of raising.
        if args.distributed:
            dist.barrier()

    if utils.is_main_process():
        with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
            f.write("best epoch: %d" % best_epoch)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
|
193 |
+
|
194 |
+
|
195 |
+
def str2bool(value):
    """Convert a command-line token to a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value (including ``'False'``) becomes ``True``. This parser
    understands the usual spellings instead.

    Args:
        value: A ``bool`` (returned unchanged) or a string.

    Returns:
        The parsed boolean. The empty string maps to ``False``, which is what
        ``bool('')`` produced before.

    Raises:
        argparse.ArgumentTypeError: If the string is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/nlvr.yaml')
    parser.add_argument('--output_dir', default='output/NLVR')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    # str2bool (not bool) so that "--distributed False" really disables it.
    parser.add_argument('--distributed', default=True, type=str2bool)
    args = parser.parse_args()

    # NOTE(review): yaml.Loader can construct arbitrary Python objects; only
    # point --config at trusted files.
    with open(args.config, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    # Keep a copy of the effective configuration next to the outputs.
    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as config_copy:
        yaml.dump(config, config_copy)

    main(args, config)
|
train_retrieval.py
ADDED
@@ -0,0 +1,345 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
'''
|
8 |
+
import argparse
|
9 |
+
import os
|
10 |
+
import ruamel_yaml as yaml
|
11 |
+
import numpy as np
|
12 |
+
import random
|
13 |
+
import time
|
14 |
+
import datetime
|
15 |
+
import json
|
16 |
+
from pathlib import Path
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
import torch.backends.cudnn as cudnn
|
22 |
+
import torch.distributed as dist
|
23 |
+
from torch.utils.data import DataLoader
|
24 |
+
|
25 |
+
from models.blip_retrieval import blip_retrieval
|
26 |
+
import utils
|
27 |
+
from utils import cosine_lr_schedule
|
28 |
+
from data import create_dataset, create_sampler, create_loader
|
29 |
+
|
30 |
+
|
31 |
+
def train(model, data_loader, optimizer, epoch, device, config):
    """Run one retrieval training epoch.

    The model returns two losses (``loss_ita`` and ``loss_itm``) whose sum is
    back-propagated. The ``alpha`` passed to the model equals
    ``config['alpha']`` from the second epoch on; during epoch 0 it is scaled
    by ``i / len(data_loader)`` (capped at 1), so it grows with the step index.

    Returns:
        A dict mapping each meter name to its global average formatted with
        three decimals.
    """
    model.train()

    logger = utils.MetricLogger(delimiter=" ")
    logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
    logger.add_meter('loss_itm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
    logger.add_meter('loss_ita', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
    log_header = f'Train Epoch: [{epoch}]'
    log_interval = 50

    for i, (image, caption, idx) in enumerate(logger.log_every(data_loader, log_interval, log_header)):
        image = image.to(device, non_blocking=True)
        idx = idx.to(device, non_blocking=True)

        # Full alpha after the first epoch; a step-proportional fraction before.
        alpha = config['alpha'] if epoch > 0 else config['alpha'] * min(1, i / len(data_loader))

        loss_ita, loss_itm = model(image, caption, alpha=alpha, idx=idx)
        loss = loss_ita + loss_itm

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logger.update(loss_itm=loss_itm.item())
        logger.update(loss_ita=loss_ita.item())
        logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    logger.synchronize_between_processes()
    print("Averaged stats:", logger.global_avg())

    summary = {}
    for name, meter in logger.meters.items():
        summary[name] = format(meter.global_avg, '.3f')
    return summary
|
66 |
+
|
67 |
+
|
68 |
+
@torch.no_grad()
def evaluation(model, data_loader, device, config):
    """Compute image->text and text->image score matrices for one split.

    Text and image embeddings are computed first; their dot products give a
    similarity matrix. For every row handled by this rank, the top
    ``config['k_test']`` candidates are re-scored with the cross-attending
    text encoder and ``model.itm_head`` (column 1 of its output is used —
    presumably the "match" logit; confirm against the model definition).
    Entries that are never re-scored keep the fill value ``-100.0``.

    Args:
        model: retrieval model exposing ``tokenizer``, ``text_encoder``,
            ``text_proj``, ``visual_encoder``, ``vision_proj`` and ``itm_head``.
        data_loader: yields ``(image, img_id)``; its dataset exposes
            ``text`` and ``image`` sequences.
        device: device used for the forward passes.
        config: mapping; only ``config['k_test']`` is read here.

    Returns:
        ``(score_matrix_i2t, score_matrix_t2i)`` as NumPy arrays of shape
        ``(len(dataset.image), len(dataset.text))`` and its transpose shape.
    """
    model.eval()

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Evaluation:'
    k_test = config['k_test']

    print('Computing features for evaluation...')
    start_time = time.time()

    # ---- text side: tokenize and embed the captions in chunks ----
    texts = data_loader.dataset.text
    num_text = len(texts)
    text_bs = 256
    text_ids = []
    text_embeds = []
    text_atts = []
    for i in range(0, num_text, text_bs):
        text = texts[i: min(num_text, i + text_bs)]
        text_input = model.tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device)
        text_output = model.text_encoder(text_input.input_ids, attention_mask=text_input.attention_mask, mode='text')
        text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:, 0, :]))
        text_embeds.append(text_embed)
        text_ids.append(text_input.input_ids)
        text_atts.append(text_input.attention_mask)

    text_embeds = torch.cat(text_embeds, dim=0)
    text_ids = torch.cat(text_ids, dim=0)
    text_atts = torch.cat(text_atts, dim=0)
    # The first token id of every caption is replaced by the tokenizer's
    # enc_token_id before the cross-attention passes below.
    text_ids[:, 0] = model.tokenizer.enc_token_id

    # ---- image side: full features are parked on the CPU, embeddings stay on device ----
    image_feats = []
    image_embeds = []
    for image, img_id in data_loader:
        image = image.to(device)
        image_feat = model.visual_encoder(image)
        image_embed = model.vision_proj(image_feat[:, 0, :])
        image_embed = F.normalize(image_embed, dim=-1)

        image_feats.append(image_feat.cpu())
        image_embeds.append(image_embed)

    image_feats = torch.cat(image_feats, dim=0)
    image_embeds = torch.cat(image_embeds, dim=0)

    sims_matrix = image_embeds @ text_embeds.t()
    score_matrix_i2t = torch.full((len(data_loader.dataset.image), len(texts)), -100.0).to(device)

    # Each rank re-scores one contiguous slice of rows.
    num_tasks = utils.get_world_size()
    rank = utils.get_rank()
    step = sims_matrix.size(0) // num_tasks + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)

    for i, sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
        topk_sim, topk_idx = sims.topk(k=k_test, dim=0)

        encoder_output = image_feats[start + i].repeat(k_test, 1, 1).to(device)
        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device)
        output = model.text_encoder(text_ids[topk_idx],
                                    attention_mask=text_atts[topk_idx],
                                    encoder_hidden_states=encoder_output,
                                    encoder_attention_mask=encoder_att,
                                    return_dict=True,
                                    )
        score = model.itm_head(output.last_hidden_state[:, 0, :])[:, 1]
        score_matrix_i2t[start + i, topk_idx] = score + topk_sim

    sims_matrix = sims_matrix.t()
    score_matrix_t2i = torch.full((len(texts), len(data_loader.dataset.image)), -100.0).to(device)

    step = sims_matrix.size(0) // num_tasks + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)

    for i, sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):

        topk_sim, topk_idx = sims.topk(k=k_test, dim=0)
        # image_feats lives on the CPU, so the index tensor is moved there
        # first (a no-op when it is already a CPU tensor).
        encoder_output = image_feats[topk_idx.cpu()].to(device)
        encoder_att = torch.ones(encoder_output.size()[:-1], dtype=torch.long).to(device)
        output = model.text_encoder(text_ids[start + i].repeat(k_test, 1),
                                    attention_mask=text_atts[start + i].repeat(k_test, 1),
                                    encoder_hidden_states=encoder_output,
                                    encoder_attention_mask=encoder_att,
                                    return_dict=True,
                                    )
        score = model.itm_head(output.last_hidden_state[:, 0, :])[:, 1]
        score_matrix_t2i[start + i, topk_idx] = score + topk_sim

    # Ask the process-group state directly instead of reading a module-global
    # `args`, so the function also works when imported from another module.
    if utils.is_dist_avail_and_initialized():
        dist.barrier()
        torch.distributed.all_reduce(score_matrix_i2t, op=torch.distributed.ReduceOp.SUM)
        torch.distributed.all_reduce(score_matrix_t2i, op=torch.distributed.ReduceOp.SUM)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Evaluation time {}'.format(total_time_str))

    return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
|
167 |
+
|
168 |
+
|
169 |
+
|
170 |
+
@torch.no_grad()
def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
    """Compute recall@1/5/10 for both retrieval directions.

    Args:
        scores_i2t: 2-D array, one row of text scores per image.
        scores_t2i: 2-D array, one row of image scores per text.
        txt2img: indexable by text index, giving the ground-truth image index.
        img2txt: indexable by image index, giving an iterable of
            ground-truth text indices.

    Returns:
        dict with keys ``txt_r1``, ``txt_r5``, ``txt_r10``, ``txt_r_mean``,
        ``img_r1``, ``img_r5``, ``img_r10``, ``img_r_mean`` and ``r_mean``
        (percentages as floats).
    """
    # Images->Text: the rank of an image is the best (lowest) position among
    # its ground-truth captions; 1e20 is kept when it has none.
    ranks = np.zeros(scores_i2t.shape[0])
    for index, score in enumerate(scores_i2t):
        position = _rank_positions(score)
        ranks[index] = min((position[i] for i in img2txt[index]), default=1e20)

    tr1 = _recall_at(ranks, 1)
    tr5 = _recall_at(ranks, 5)
    tr10 = _recall_at(ranks, 10)

    # Text->Images: each text has exactly one ground-truth image.
    ranks = np.zeros(scores_t2i.shape[0])
    for index, score in enumerate(scores_t2i):
        ranks[index] = _rank_positions(score)[txt2img[index]]

    ir1 = _recall_at(ranks, 1)
    ir5 = _recall_at(ranks, 5)
    ir10 = _recall_at(ranks, 10)

    tr_mean = (tr1 + tr5 + tr10) / 3
    ir_mean = (ir1 + ir5 + ir10) / 3
    r_mean = (tr_mean + ir_mean) / 2

    eval_result = {'txt_r1': tr1,
                   'txt_r5': tr5,
                   'txt_r10': tr10,
                   'txt_r_mean': tr_mean,
                   'img_r1': ir1,
                   'img_r5': ir5,
                   'img_r10': ir10,
                   'img_r_mean': ir_mean,
                   'r_mean': r_mean}
    return eval_result


def _rank_positions(score):
    """Return an array whose entry ``j`` is the descending-score rank of item ``j``."""
    order = np.argsort(score)[::-1]
    position = np.empty(len(order), dtype=np.int64)
    # Inverting the permutation once replaces a full scan per looked-up item.
    position[order] = np.arange(len(order))
    return position


def _recall_at(ranks, k):
    """Return the percentage of entries in ``ranks`` strictly below ``k``."""
    return 100.0 * int(np.count_nonzero(ranks < k)) / len(ranks)
|
216 |
+
|
217 |
+
|
218 |
+
def main(args, config):
    """Train and/or evaluate the retrieval model described by ``config``.

    Args:
        args: parsed command-line namespace (``device``, ``seed``,
            ``evaluate``, ``output_dir``; ``distributed`` and ``gpu`` are
            filled in by ``utils.init_distributed_mode``).
        config: mapping loaded from the YAML config file.
    """
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating retrieval dataset")
    train_dataset, val_dataset, test_dataset = create_dataset('retrieval_%s' % config['dataset'], config)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None, None]
    else:
        samplers = [None, None, None]

    train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset], samplers,
                                                          batch_size=[config['batch_size_train']] + [config['batch_size_test']] * 2,
                                                          num_workers=[4, 4, 4],
                                                          is_trains=[True, False, False],
                                                          collate_fns=[None, None, None])

    #### Model ####
    print("Creating model")
    model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
                           vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
                           queue_size=config['queue_size'], negative_all_rank=config['negative_all_rank'])

    model = model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])

    best = 0
    best_epoch = 0

    print("Start training")
    start_time = time.time()

    for epoch in range(0, config['max_epoch']):
        if not args.evaluate:
            if args.distributed:
                train_loader.sampler.set_epoch(epoch)

            cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])

            train_stats = train(model, train_loader, optimizer, epoch, device, config)

        score_val_i2t, score_val_t2i, = evaluation(model_without_ddp, val_loader, device, config)
        score_test_i2t, score_test_t2i = evaluation(model_without_ddp, test_loader, device, config)

        if utils.is_main_process():

            val_result = itm_eval(score_val_i2t, score_val_t2i, val_loader.dataset.txt2img, val_loader.dataset.img2txt)
            print(val_result)

            if val_result['r_mean'] > best:
                save_obj = {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'config': config,
                    'epoch': epoch,
                }
                torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
                best = val_result['r_mean']
                best_epoch = epoch

            # NOTE(review): the scraped source lost its indentation, so it is
            # unclear whether this ran only on a new best. It is computed
            # every epoch here so `test_result` is always defined and matches
            # the logged `epoch` — confirm against the intended behaviour.
            test_result = itm_eval(score_test_i2t, score_test_t2i, test_loader.dataset.txt2img, test_loader.dataset.img2txt)
            print(test_result)

            if args.evaluate:
                log_stats = {**{f'val_{k}': v for k, v in val_result.items()},
                             **{f'test_{k}': v for k, v in test_result.items()},
                             }
                with open(os.path.join(args.output_dir, "evaluate.txt"), "a") as f:
                    f.write(json.dumps(log_stats) + "\n")
            else:
                log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                             **{f'val_{k}': v for k, v in val_result.items()},
                             **{f'test_{k}': v for k, v in test_result.items()},
                             'epoch': epoch,
                             'best_epoch': best_epoch,
                             }
                with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                    f.write(json.dumps(log_stats) + "\n")

        if args.evaluate:
            break

        # A barrier needs an initialised process group; skip it for
        # single-process runs, where init_distributed_mode creates none.
        if args.distributed:
            dist.barrier()
        torch.cuda.empty_cache()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
|
325 |
+
|
326 |
+
|
327 |
+
if __name__ == '__main__':
    # Command-line entry point: parse options, load the YAML config, keep a
    # copy of it next to the outputs, then hand over to main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/retrieval_flickr.yaml')
    parser.add_argument('--output_dir', default='output/Retrieval_flickr')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    # NOTE(review): type=bool turns any non-empty string (including "False")
    # into True; the value is overwritten by utils.init_distributed_mode anyway.
    parser.add_argument('--distributed', default=True, type=bool)
    args = parser.parse_args()

    # NOTE(review): yaml.Loader can construct arbitrary objects; only load
    # trusted config files.
    with open(args.config, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)

    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as config_copy:
        yaml.dump(config, config_copy)

    main(args, config)
|
train_vqa.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
'''
|
2 |
+
* Copyright (c) 2022, salesforce.com, inc.
|
3 |
+
* All rights reserved.
|
4 |
+
* SPDX-License-Identifier: BSD-3-Clause
|
5 |
+
* For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
6 |
+
* By Junnan Li
|
7 |
+
'''
|
8 |
+
import argparse
|
9 |
+
import os
|
10 |
+
import ruamel_yaml as yaml
|
11 |
+
import numpy as np
|
12 |
+
import random
|
13 |
+
import time
|
14 |
+
import datetime
|
15 |
+
import json
|
16 |
+
from pathlib import Path
|
17 |
+
|
18 |
+
import torch
|
19 |
+
import torch.nn as nn
|
20 |
+
import torch.nn.functional as F
|
21 |
+
from torch.utils.data import DataLoader
|
22 |
+
import torch.backends.cudnn as cudnn
|
23 |
+
import torch.distributed as dist
|
24 |
+
|
25 |
+
from models.blip_vqa import blip_vqa
|
26 |
+
import utils
|
27 |
+
from utils import cosine_lr_schedule
|
28 |
+
from data import create_dataset, create_sampler, create_loader
|
29 |
+
from data.vqa_dataset import vqa_collate_fn
|
30 |
+
from data.utils import save_result
|
31 |
+
|
32 |
+
|
33 |
+
def train(model, data_loader, optimizer, epoch, device):
    """Run one VQA training epoch and return the epoch-averaged meters.

    Args:
        model: callable invoked as
            ``model(image, question, answer, train=True, n=n, weights=weights)``
            and returning the loss.
        data_loader: yields ``(image, question, answer, weights, n)`` batches.
        optimizer: optimizer stepped once per batch.
        epoch: epoch index, used only in the log header.
        device: device the image and weights tensors are moved to.

    Returns:
        dict mapping each meter name to its global average formatted
        with three decimals (as a string).
    """
    model.train()

    metric_logger = utils.MetricLogger(delimiter=" ")
    lr_meter = utils.SmoothedValue(window_size=1, fmt='{value:.6f}')
    metric_logger.add_meter('lr', lr_meter)
    loss_meter = utils.SmoothedValue(window_size=1, fmt='{value:.4f}')
    metric_logger.add_meter('loss', loss_meter)

    header = 'Train Epoch: [{}]'.format(epoch)
    print_freq = 50

    batches = metric_logger.log_every(data_loader, print_freq, header)
    for _step, (image, question, answer, weights, n) in enumerate(batches):
        image = image.to(device, non_blocking=True)
        weights = weights.to(device, non_blocking=True)

        loss = model(image, question, answer, train=True, n=n, weights=weights)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        metric_logger.update(loss=loss.item())
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    print("Averaged stats:", metric_logger.global_avg())
    averaged = {}
    for name, meter in metric_logger.meters.items():
        averaged[name] = "{:.3f}".format(meter.global_avg)
    return averaged
|
60 |
+
|
61 |
+
|
62 |
+
@torch.no_grad()
def evaluation(model, data_loader, device, config):
    """Produce VQA answers for every question served by ``data_loader``.

    Args:
        model: VQA model exposing ``tokenizer`` and callable in inference mode.
        data_loader: yields ``(image, question, question_id)``; in ``'rank'``
            mode its dataset must expose ``answer_list``.
        device: device used for the forward passes.
        config: mapping; ``config['inference']`` selects ``'generate'`` or
            ``'rank'`` (``'rank'`` additionally reads ``config['k_test']``).

    Returns:
        list of ``{"question_id": int, "answer": ...}`` dicts.

    Raises:
        ValueError: if ``config['inference']`` is neither ``'generate'`` nor
            ``'rank'`` (previously this ran the whole loader and silently
            returned an empty list).
    """
    inference = config['inference']
    if inference not in ('generate', 'rank'):
        raise ValueError(
            "config['inference'] must be 'generate' or 'rank', got {!r}".format(inference))

    model.eval()

    metric_logger = utils.MetricLogger(delimiter=" ")
    header = 'Generate VQA test result:'
    print_freq = 50

    result = []

    if inference == 'rank':
        answer_list = data_loader.dataset.answer_list
        answer_candidates = model.tokenizer(answer_list, padding='longest', return_tensors='pt').to(device)
        # The first token id of every candidate is replaced by bos_token_id.
        answer_candidates.input_ids[:, 0] = model.tokenizer.bos_token_id

    for image, question, question_id in metric_logger.log_every(data_loader, print_freq, header):
        image = image.to(device, non_blocking=True)

        if inference == 'generate':
            answers = model(image, question, train=False, inference='generate')

            for answer, ques_id in zip(answers, question_id):
                ques_id = int(ques_id.item())
                result.append({"question_id": ques_id, "answer": answer})

        else:  # 'rank'
            answer_ids = model(image, question, answer_candidates, train=False, inference='rank', k_test=config['k_test'])

            for ques_id, answer_id in zip(question_id, answer_ids):
                result.append({"question_id": int(ques_id.item()), "answer": answer_list[answer_id]})

    return result
|
95 |
+
|
96 |
+
|
97 |
+
def main(args, config):
    """Train the VQA model (unless ``args.evaluate``) and write test answers.

    Args:
        args: parsed command-line namespace (``device``, ``seed``,
            ``evaluate``, ``output_dir``, ``result_dir``; ``distributed`` and
            ``gpu`` are filled in by ``utils.init_distributed_mode``).
        config: mapping loaded from the YAML config file.
    """
    utils.init_distributed_mode(args)

    device = torch.device(args.device)

    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    #### Dataset ####
    print("Creating vqa datasets")
    datasets = create_dataset('vqa', config)

    if args.distributed:
        num_tasks = utils.get_world_size()
        global_rank = utils.get_rank()
        samplers = create_sampler(datasets, [True, False], num_tasks, global_rank)
    else:
        samplers = [None, None]

    train_loader, test_loader = create_loader(datasets, samplers,
                                              batch_size=[config['batch_size_train'], config['batch_size_test']],
                                              num_workers=[4, 4], is_trains=[True, False],
                                              collate_fns=[vqa_collate_fn, None])
    #### Model ####
    print("Creating model")
    model = blip_vqa(pretrained=config['pretrained'], image_size=config['image_size'],
                     vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'])

    model = model.to(device)

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])

    print("Start training")
    start_time = time.time()
    for epoch in range(0, config['max_epoch']):
        # Evaluation-only runs skip training entirely.
        if args.evaluate:
            break

        if args.distributed:
            train_loader.sampler.set_epoch(epoch)

        cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])

        train_stats = train(model, train_loader, optimizer, epoch, device)

        if utils.is_main_process():
            log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
                         'epoch': epoch,
                         }
            with open(os.path.join(args.output_dir, "log.txt"), "a") as f:
                f.write(json.dumps(log_stats) + "\n")

            save_obj = {
                'model': model_without_ddp.state_dict(),
                'optimizer': optimizer.state_dict(),
                'config': config,
                'epoch': epoch,
            }
            torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth' % epoch))

        # A barrier needs an initialised process group; skip it for
        # single-process runs, where init_distributed_mode creates none.
        if args.distributed:
            dist.barrier()

    vqa_result = evaluation(model_without_ddp, test_loader, device, config)
    save_result(vqa_result, args.result_dir, 'vqa_result')

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
|
178 |
+
|
179 |
+
|
180 |
+
|
181 |
+
if __name__ == '__main__':
    # Command-line entry point: parse options, load the YAML config, create
    # the output folders, keep a copy of the config, then hand over to main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', default='./configs/vqa.yaml')
    parser.add_argument('--output_dir', default='output/VQA')
    parser.add_argument('--evaluate', action='store_true')
    parser.add_argument('--device', default='cuda')
    parser.add_argument('--seed', default=42, type=int)
    parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
    parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
    # NOTE(review): type=bool turns any non-empty string (including "False")
    # into True; the value is overwritten by utils.init_distributed_mode anyway.
    parser.add_argument('--distributed', default=True, type=bool)
    args = parser.parse_args()

    # NOTE(review): yaml.Loader can construct arbitrary objects; only load
    # trusted config files.
    with open(args.config, 'r') as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)

    args.result_dir = os.path.join(args.output_dir, 'result')

    Path(args.output_dir).mkdir(parents=True, exist_ok=True)
    Path(args.result_dir).mkdir(parents=True, exist_ok=True)

    with open(os.path.join(args.output_dir, 'config.yaml'), 'w') as config_copy:
        yaml.dump(config, config_copy)

    main(args, config)
|
utils.py
ADDED
@@ -0,0 +1,278 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
    """Set every param group's lr on a half-cosine curve from init_lr to min_lr.

    At ``epoch == 0`` the rate is ``init_lr``; at ``epoch == max_epoch`` it is
    ``min_lr``. The optimizer is modified in place; nothing is returned.
    """
    angle = math.pi * epoch / max_epoch
    lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(angle)) + min_lr
    for group in optimizer.param_groups:
        group['lr'] = lr
|
7 |
+
|
8 |
+
def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
    """Linearly warm the learning rate up from init_lr, capped at max_lr.

    Args:
        optimizer: object with a ``param_groups`` list of dicts; modified in place.
        step: current warmup step.
        max_step: number of steps over which the rate reaches ``max_lr``.
            A value of zero or less means "no warmup" and yields ``max_lr``
            directly instead of dividing by zero.
        init_lr: rate at step 0.
        max_lr: upper bound of the rate.
    """
    if max_step <= 0:
        lr = max_lr
    else:
        lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
|
13 |
+
|
14 |
+
def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
    """Set every param group's lr to init_lr * decay_rate**epoch, floored at min_lr.

    The optimizer is modified in place; nothing is returned.
    """
    decayed = init_lr * (decay_rate**epoch)
    lr = max(min_lr, decayed)
    for group in optimizer.param_groups:
        group['lr'] = lr
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import io
|
22 |
+
import os
|
23 |
+
import time
|
24 |
+
from collections import defaultdict, deque
|
25 |
+
import datetime
|
26 |
+
|
27 |
+
import torch
|
28 |
+
import torch.distributed as dist
|
29 |
+
|
30 |
+
class SmoothedValue(object):
    """Running statistic over a stream of numbers.

    Keeps the most recent ``window_size`` values in a bounded deque (for
    median / avg / max / value) together with an all-time total and count
    (for global_avg).
    """

    def __init__(self, window_size=20, fmt=None):
        # fmt is a str.format template that may reference median, avg,
        # global_avg, max and value.
        self.fmt = "{median:.4f} ({global_avg:.4f})" if fmt is None else fmt
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0

    def update(self, value, n=1):
        """Record ``value``, weighting it ``n`` times in the global average."""
        self.deque.append(value)
        self.count += n
        self.total += value * n

    def synchronize_between_processes(self):
        """
        Sum count and total across processes.

        Warning: does not synchronize the deque!
        """
        if not is_dist_avail_and_initialized():
            return
        packed = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        dist.barrier()
        dist.all_reduce(packed)
        synced_count, synced_total = packed.tolist()
        self.count = int(synced_count)
        self.total = synced_total

    @property
    def median(self):
        """Median of the values currently in the window."""
        window = torch.tensor(list(self.deque))
        return window.median().item()

    @property
    def avg(self):
        """Mean of the values currently in the window."""
        window = torch.tensor(list(self.deque), dtype=torch.float32)
        return window.mean().item()

    @property
    def global_avg(self):
        """All-time weighted mean (total / count)."""
        return self.total / self.count

    @property
    def max(self):
        """Largest value currently in the window."""
        return max(self.deque)

    @property
    def value(self):
        """Most recently recorded value."""
        return self.deque[-1]

    def __str__(self):
        stats = {
            'median': self.median,
            'avg': self.avg,
            'global_avg': self.global_avg,
            'max': self.max,
            'value': self.value,
        }
        return self.fmt.format(**stats)
|
90 |
+
|
91 |
+
|
92 |
+
class MetricLogger(object):
    """Collection of named SmoothedValue meters plus a progress printer."""

    def __init__(self, delimiter="\t"):
        # Unknown meter names get a default SmoothedValue on first access.
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter

    def update(self, **kwargs):
        """Feed one value per keyword into the meter of the same name."""
        for k, v in kwargs.items():
            if isinstance(v, torch.Tensor):
                v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)

    def __getattr__(self, attr):
        # Only called when normal lookup fails: expose meters as attributes.
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))

    def __str__(self):
        loss_str = [
            "{}: {}".format(name, str(meter))
            for name, meter in self.meters.items()
        ]
        return self.delimiter.join(loss_str)

    def global_avg(self):
        """Return 'name: global_avg' for every meter, joined by the delimiter."""
        loss_str = [
            "{}: {:.4f}".format(name, meter.global_avg)
            for name, meter in self.meters.items()
        ]
        return self.delimiter.join(loss_str)

    def synchronize_between_processes(self):
        """Synchronize every meter's count/total across processes."""
        for meter in self.meters.values():
            meter.synchronize_between_processes()

    def add_meter(self, name, meter):
        """Register ``meter`` under ``name``, replacing any existing one."""
        self.meters[name] = meter

    def log_every(self, iterable, print_freq, header=None):
        """Yield the items of ``iterable``, printing progress periodically.

        A line is printed every ``print_freq`` items and for the last item;
        a total-time summary is printed once the iterable is exhausted.
        ``iterable`` must support ``len()``. An empty iterable only prints
        the summary (it previously raised ZeroDivisionError there).
        """
        i = 0
        if not header:
            header = ''
        total = len(iterable)
        use_cuda = torch.cuda.is_available()
        start_time = time.time()
        end = time.time()
        iter_time = SmoothedValue(fmt='{avg:.4f}')
        data_time = SmoothedValue(fmt='{avg:.4f}')
        # Pad the item counter to the width of the total.
        space_fmt = ':' + str(len(str(total))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        if use_cuda:
            log_msg.append('max mem: {memory:.0f}')
        log_msg = self.delimiter.join(log_msg)
        MB = 1024.0 * 1024.0
        for obj in iterable:
            data_time.update(time.time() - end)
            yield obj
            iter_time.update(time.time() - end)
            if i % print_freq == 0 or i == total - 1:
                eta_seconds = iter_time.global_avg * (total - i)
                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                fields = {
                    'eta': eta_string,
                    'meters': str(self),
                    'time': str(iter_time),
                    'data': str(data_time),
                }
                if use_cuda:
                    fields['memory'] = torch.cuda.max_memory_allocated() / MB
                print(log_msg.format(i, total, **fields))
            i += 1
            end = time.time()
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        # max(total, 1) keeps the per-item figure defined for empty iterables.
        print('{} Total time: {} ({:.4f} s / it)'.format(
            header, total_time_str, total_time / max(total, 1)))
|
180 |
+
|
181 |
+
|
182 |
+
class AttrDict(dict):
    """dict whose items are also readable and writable as attributes."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Point the instance namespace at the dict itself, so attribute
        # access and item access share one storage.
        self.__dict__ = self
|
186 |
+
|
187 |
+
|
188 |
+
def compute_acc(logits, label, reduction='mean'):
    """Compare the argmax of ``logits`` along dim 1 with ``label``.

    Args:
        logits: tensor whose dim 1 is reduced with argmax.
        label: tensor compared element-wise with the argmax result.
        reduction: ``'mean'`` returns the accuracy as a Python float;
            ``'none'`` returns the detached per-item 0/1 float tensor.

    Raises:
        ValueError: for any other ``reduction`` (previously returned None
            silently).
    """
    ret = (torch.argmax(logits, dim=1) == label).float()
    if reduction == 'none':
        return ret.detach()
    if reduction == 'mean':
        return ret.mean().item()
    raise ValueError(
        "reduction must be 'none' or 'mean', got {!r}".format(reduction))
|
194 |
+
|
195 |
+
def compute_n_params(model, return_str=True):
    """Count the elements of all tensors yielded by ``model.parameters()``.

    Args:
        model: object exposing ``parameters()``.
        return_str: when true, return a short string ('x.xM' from one
            million upwards, otherwise 'x.xK'); when false, return the int.
    """
    # numel() is the product of the shape, which the original computed by hand.
    tot = sum(p.numel() for p in model.parameters())
    if not return_str:
        return tot
    if tot >= 1e6:
        return '{:.1f}M'.format(tot / 1e6)
    return '{:.1f}K'.format(tot / 1e3)
|
209 |
+
|
210 |
+
def setup_for_distributed(is_master):
    """
    Replace the builtin print so that only the master process prints.

    Any process can still print by passing ``force=True``.
    """
    import builtins
    original_print = builtins.print

    def _gated_print(*args, force=False, **kwargs):
        # `force` is consumed here and never forwarded to the real print.
        if not (is_master or force):
            return
        original_print(*args, **kwargs)

    builtins.print = _gated_print
|
223 |
+
|
224 |
+
|
225 |
+
def is_dist_avail_and_initialized():
    """Return True only if torch.distributed is available and initialised."""
    return dist.is_available() and dist.is_initialized()
|
231 |
+
|
232 |
+
|
233 |
+
def get_world_size():
    """Return the distributed world size, or 1 when no process group is set up."""
    if is_dist_avail_and_initialized():
        return dist.get_world_size()
    return 1
|
237 |
+
|
238 |
+
|
239 |
+
def get_rank():
    """Return this process's distributed rank, or 0 when no process group is set up."""
    if is_dist_avail_and_initialized():
        return dist.get_rank()
    return 0
|
243 |
+
|
244 |
+
|
245 |
+
def is_main_process():
    """Return True when this process has rank 0."""
    rank = get_rank()
    return rank == 0
|
247 |
+
|
248 |
+
|
249 |
+
def save_on_master(*args, **kwargs):
    """Forward the arguments to torch.save on the main process; do nothing elsewhere."""
    if not is_main_process():
        return
    torch.save(*args, **kwargs)
|
252 |
+
|
253 |
+
|
254 |
+
def init_distributed_mode(args):
    """Initialise torch.distributed from the environment and record it on ``args``.

    With ``RANK`` and ``WORLD_SIZE`` set, rank, world size and GPU index are
    read from ``RANK``, ``WORLD_SIZE`` and ``LOCAL_RANK``. With
    ``SLURM_PROCID`` set, the rank comes from it and the GPU index is the
    rank modulo the visible device count (``args.world_size`` is left as
    passed in). Otherwise ``args.distributed`` is set to False and the
    function returns without touching torch.distributed.

    On the distributed path this sets ``args.distributed = True``, selects
    the CUDA device, initialises an NCCL process group, waits on a barrier
    and silences ``print`` on non-zero ranks.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    elif 'SLURM_PROCID' in os.environ:
        args.rank = int(os.environ['SLURM_PROCID'])
        args.gpu = args.rank % torch.cuda.device_count()
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    # Log message fixed: the second field is the world size ("word" was a typo).
    print('| distributed init (rank {}, world {}): {}'.format(
        args.rank, args.world_size, args.dist_url), flush=True)
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
    torch.distributed.barrier()
    setup_for_distributed(args.rank == 0)
|
277 |
+
|
278 |
+
|
walk.jpeg
ADDED