GV05 committed on
Commit 049c619
Parent: c53a443
Files changed (12)
  1. LICENSE.txt +12 -0
  2. alarm.jpeg +0 -0
  3. alarm1.jpeg +0 -0
  4. app.py +36 -0
  5. eval_nocaps.py +118 -0
  6. requirements.txt +6 -0
  7. train_caption.py +206 -0
  8. train_nlvr.py +213 -0
  9. train_retrieval.py +345 -0
  10. train_vqa.py +202 -0
  11. utils.py +278 -0
  12. walk.jpeg +0 -0
LICENSE.txt ADDED
@@ -0,0 +1,12 @@
+ Copyright (c) 2022, Salesforce.com, Inc.
+ All rights reserved.
+
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+
+ * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
+
+ * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
+
+ * Neither the name of Salesforce.com nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
+
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
alarm.jpeg ADDED
alarm1.jpeg ADDED
app.py ADDED
@@ -0,0 +1,36 @@
+ import gradio as gr
+ from models.blip_vqa import blip_vqa
+ import torch
+ from torchvision import transforms
+ from torchvision.transforms.functional import InterpolationMode
+
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+ image_size = 480
+
+ transform = transforms.Compose([
+     transforms.Resize((image_size,image_size),interpolation=InterpolationMode.BICUBIC),
+     transforms.ToTensor(),
+     transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
+     ])
+
+ model_url = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_vqa_capfilt_large.pth'
+ model = blip_vqa(pretrained=model_url, image_size=image_size, vit='base')
+ model.eval()
+ model = model.to(device)
+
+
+ def pool_alarm(raw_image):
+     question = 'is there someone in the pool?'
+     image = transform(raw_image).unsqueeze(0).to(device)
+     with torch.no_grad():
+         answer = model(image, question, train=False, inference='generate')
+
+     return 'answer: ' + answer[0]
+
+
+ input = gr.inputs.Image(type='pil')
+ output = gr.outputs.Textbox()
+ examples = ['alarm.jpeg', 'alarm1.jpeg', 'walk.jpeg']
+ title = ""
+ description = ""
+ intf = gr.Interface(fn=pool_alarm, inputs=input, outputs=output, examples=examples).launch()
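
The Gradio wiring above is a thin layer over one transform plus one `model(...)` call. For quick local testing without the UI, the same inference path can be exercised directly; a minimal sketch, assuming it runs in the same session as app.py (so `transform`, `model`, `device`, and `torch` are already in scope) and that one of the bundled example images is on disk:

```python
# Hedged sketch: call the BLIP VQA model directly, bypassing Gradio.
# Assumes transform/model/device from app.py are defined in this session.
from PIL import Image

raw_image = Image.open('alarm.jpeg').convert('RGB')   # one of the bundled examples
image = transform(raw_image).unsqueeze(0).to(device)  # shape [1, 3, 480, 480]

with torch.no_grad():
    answer = model(image, 'is there someone in the pool?',
                   train=False, inference='generate')
print('answer:', answer[0])  # a short generated string, e.g. 'yes' or 'no'
```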
eval_nocaps.py ADDED
@@ -0,0 +1,118 @@
+ '''
+  * Copyright (c) 2022, salesforce.com, inc.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+  * By Junnan Li
+ '''
+ import argparse
+ import os
+ import ruamel_yaml as yaml
+ import numpy as np
+ import random
+ import time
+ import datetime
+ import json
+ from pathlib import Path
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.backends.cudnn as cudnn
+ import torch.distributed as dist
+ from torch.utils.data import DataLoader
+
+ from models.blip import blip_decoder
+ import utils
+ from data import create_dataset, create_sampler, create_loader
+ from data.utils import save_result
+
+ @torch.no_grad()
+ def evaluate(model, data_loader, device, config):
+     # evaluate
+     model.eval()
+
+     metric_logger = utils.MetricLogger(delimiter="  ")
+     header = 'Evaluation:'
+     print_freq = 10
+
+     result = []
+     for image, image_id in metric_logger.log_every(data_loader, print_freq, header):
+
+         image = image.to(device)
+
+         captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'],
+                                   min_length=config['min_length'], repetition_penalty=1.1)
+
+         for caption, img_id in zip(captions, image_id):
+             result.append({"image_id": img_id.item(), "caption": caption})
+
+     return result
+
+
+ def main(args, config):
+     utils.init_distributed_mode(args)
+
+     device = torch.device(args.device)
+
+     # fix the seed for reproducibility
+     seed = args.seed + utils.get_rank()
+     torch.manual_seed(seed)
+     np.random.seed(seed)
+     random.seed(seed)
+     cudnn.benchmark = True
+
+     #### Dataset ####
+     print("Creating captioning dataset")
+     val_dataset, test_dataset = create_dataset('nocaps', config)
+
+     if args.distributed:
+         num_tasks = utils.get_world_size()
+         global_rank = utils.get_rank()
+         samplers = create_sampler([val_dataset,test_dataset], [False,False], num_tasks, global_rank)
+     else:
+         samplers = [None,None]
+
+     val_loader, test_loader = create_loader([val_dataset, test_dataset],samplers,
+                                             batch_size=[config['batch_size']]*2,num_workers=[4,4],
+                                             is_trains=[False, False], collate_fns=[None,None])
+
+     #### Model ####
+     print("Creating model")
+     model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
+                          prompt=config['prompt'])
+
+     model = model.to(device)
+
+     model_without_ddp = model
+     if args.distributed:
+         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+         model_without_ddp = model.module
+
+     val_result = evaluate(model_without_ddp, val_loader, device, config)
+     val_result_file = save_result(val_result, args.result_dir, 'val', remove_duplicate='image_id')
+     test_result = evaluate(model_without_ddp, test_loader, device, config)
+     test_result_file = save_result(test_result, args.result_dir, 'test', remove_duplicate='image_id')
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--config', default='./configs/nocaps.yaml')
+     parser.add_argument('--output_dir', default='output/NoCaps')
+     parser.add_argument('--device', default='cuda')
+     parser.add_argument('--seed', default=42, type=int)
+     parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+     parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+     parser.add_argument('--distributed', default=True, type=bool)
+     args = parser.parse_args()
+
+     config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+
+     args.result_dir = os.path.join(args.output_dir, 'result')
+
+     Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+     Path(args.result_dir).mkdir(parents=True, exist_ok=True)
+
+     yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+     main(args, config)
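
`save_result` writes each split's captions as a JSON list of `{"image_id", "caption"}` records, the fields assembled in `evaluate` above. A hedged sketch for inspecting the validation output, assuming `save_result` names the file after the split under `args.result_dir` (its implementation lives in `data/utils.py`, outside this commit):

```python
# Hedged sketch: inspect the captions written for the 'val' split.
# The file name is an assumption -- it comes from data.utils.save_result,
# which is not part of this commit; adjust the path to whatever was written.
import json

with open('output/NoCaps/result/val.json') as f:
    results = json.load(f)

print(len(results), 'captions')
print(results[0])  # e.g. {'image_id': 0, 'caption': 'a dog running on the beach'}
```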
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ timm==0.4.12
+ transformers==4.15.0
+ fairscale==0.4.4
+ pycocoevalcap
+ torch
+ torchvision
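
The exact pins matter here: the BLIP model code was written against timm 0.4.12 and transformers 4.15.0, and drifting to newer versions of either can break it. A small, purely illustrative environment check:

```python
# Check installed versions against the pins in requirements.txt.
# importlib.metadata is standard library (Python 3.8+).
from importlib.metadata import version, PackageNotFoundError

pins = {'timm': '0.4.12', 'transformers': '4.15.0', 'fairscale': '0.4.4'}
for pkg, want in pins.items():
    try:
        have = version(pkg)
        print(f"{pkg}: {have}" + ('' if have == want else f' (pinned: {want})'))
    except PackageNotFoundError:
        print(f'{pkg}: not installed')
```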
train_caption.py ADDED
@@ -0,0 +1,206 @@
+ '''
+  * Copyright (c) 2022, salesforce.com, inc.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+  * By Junnan Li
+ '''
+ import argparse
+ import os
+ import ruamel_yaml as yaml
+ import numpy as np
+ import random
+ import time
+ import datetime
+ import json
+ from pathlib import Path
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.backends.cudnn as cudnn
+ import torch.distributed as dist
+ from torch.utils.data import DataLoader
+
+ from models.blip import blip_decoder
+ import utils
+ from utils import cosine_lr_schedule
+ from data import create_dataset, create_sampler, create_loader
+ from data.utils import save_result, coco_caption_eval
+
+ def train(model, data_loader, optimizer, epoch, device):
+     # train
+     model.train()
+
+     metric_logger = utils.MetricLogger(delimiter="  ")
+     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+     metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
+     header = 'Train Caption Epoch: [{}]'.format(epoch)
+     print_freq = 50
+
+     for i, (image, caption, _) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+         image = image.to(device)
+
+         loss = model(image, caption)
+
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+
+         metric_logger.update(loss=loss.item())
+         metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+     # gather the stats from all processes
+     metric_logger.synchronize_between_processes()
+     print("Averaged stats:", metric_logger.global_avg())
+     return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
+
+
+ @torch.no_grad()
+ def evaluate(model, data_loader, device, config):
+     # evaluate
+     model.eval()
+
+     metric_logger = utils.MetricLogger(delimiter="  ")
+     header = 'Caption generation:'
+     print_freq = 10
+
+     result = []
+     for image, image_id in metric_logger.log_every(data_loader, print_freq, header):
+
+         image = image.to(device)
+
+         captions = model.generate(image, sample=False, num_beams=config['num_beams'], max_length=config['max_length'],
+                                   min_length=config['min_length'])
+
+         for caption, img_id in zip(captions, image_id):
+             result.append({"image_id": img_id.item(), "caption": caption})
+
+     return result
+
+
+ def main(args, config):
+     utils.init_distributed_mode(args)
+
+     device = torch.device(args.device)
+
+     # fix the seed for reproducibility
+     seed = args.seed + utils.get_rank()
+     torch.manual_seed(seed)
+     np.random.seed(seed)
+     random.seed(seed)
+     cudnn.benchmark = True
+
+     #### Dataset ####
+     print("Creating captioning dataset")
+     train_dataset, val_dataset, test_dataset = create_dataset('caption_coco', config)
+
+     if args.distributed:
+         num_tasks = utils.get_world_size()
+         global_rank = utils.get_rank()
+         samplers = create_sampler([train_dataset,val_dataset,test_dataset], [True,False,False], num_tasks, global_rank)
+     else:
+         samplers = [None, None, None]
+
+     train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers,
+                                                           batch_size=[config['batch_size']]*3,num_workers=[4,4,4],
+                                                           is_trains=[True, False, False], collate_fns=[None,None,None])
+
+     #### Model ####
+     print("Creating model")
+     model = blip_decoder(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
+                          vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
+                          prompt=config['prompt'])
+
+     model = model.to(device)
+
+     model_without_ddp = model
+     if args.distributed:
+         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+         model_without_ddp = model.module
+
+     optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
+
+     best = 0
+     best_epoch = 0
+
+     print("Start training")
+     start_time = time.time()
+     for epoch in range(0, config['max_epoch']):
+         if not args.evaluate:
+             if args.distributed:
+                 train_loader.sampler.set_epoch(epoch)
+
+             cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
+
+             train_stats = train(model, train_loader, optimizer, epoch, device)
+
+         val_result = evaluate(model_without_ddp, val_loader, device, config)
+         val_result_file = save_result(val_result, args.result_dir, 'val_epoch%d'%epoch, remove_duplicate='image_id')
+
+         test_result = evaluate(model_without_ddp, test_loader, device, config)
+         test_result_file = save_result(test_result, args.result_dir, 'test_epoch%d'%epoch, remove_duplicate='image_id')
+
+         if utils.is_main_process():
+             coco_val = coco_caption_eval(config['coco_gt_root'],val_result_file,'val')
+             coco_test = coco_caption_eval(config['coco_gt_root'],test_result_file,'test')
+
+             if args.evaluate:
+                 log_stats = {**{f'val_{k}': v for k, v in coco_val.eval.items()},
+                              **{f'test_{k}': v for k, v in coco_test.eval.items()},
+                             }
+                 with open(os.path.join(args.output_dir, "evaluate.txt"),"a") as f:
+                     f.write(json.dumps(log_stats) + "\n")
+             else:
+                 save_obj = {
+                     'model': model_without_ddp.state_dict(),
+                     'optimizer': optimizer.state_dict(),
+                     'config': config,
+                     'epoch': epoch,
+                 }
+
+                 if coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4'] > best:
+                     best = coco_val.eval['CIDEr'] + coco_val.eval['Bleu_4']
+                     best_epoch = epoch
+                     torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
+
+                 log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+                              **{f'val_{k}': v for k, v in coco_val.eval.items()},
+                              **{f'test_{k}': v for k, v in coco_test.eval.items()},
+                              'epoch': epoch,
+                              'best_epoch': best_epoch,
+                             }
+                 with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
+                     f.write(json.dumps(log_stats) + "\n")
+
+         if args.evaluate:
+             break
+         dist.barrier()
+
+     total_time = time.time() - start_time
+     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+     print('Training time {}'.format(total_time_str))
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--config', default='./configs/caption_coco.yaml')
+     parser.add_argument('--output_dir', default='output/Caption_coco')
+     parser.add_argument('--evaluate', action='store_true')
+     parser.add_argument('--device', default='cuda')
+     parser.add_argument('--seed', default=42, type=int)
+     parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+     parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+     parser.add_argument('--distributed', default=True, type=bool)
+     args = parser.parse_args()
+
+     config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+
+     args.result_dir = os.path.join(args.output_dir, 'result')
+
+     Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+     Path(args.result_dir).mkdir(parents=True, exist_ok=True)
+
+     yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+     main(args, config)
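
The checkpoint layout is fixed by `save_obj` above: a dict with 'model', 'optimizer', 'config', and 'epoch' keys, with the weights taken from the unwrapped (non-DDP) module. A sketch for reloading the best checkpoint, assuming the default `--output_dir` was used:

```python
# Sketch: reload the best captioning checkpoint written by train_caption.py.
# The dict keys mirror save_obj above; the path assumes the default --output_dir.
# Run from the BLIP repo root so the model's default config paths resolve.
import torch
from models.blip import blip_decoder  # same import as the script above

ckpt = torch.load('output/Caption_coco/checkpoint_best.pth', map_location='cpu')
config = ckpt['config']
model = blip_decoder(image_size=config['image_size'], vit=config['vit'],
                     prompt=config['prompt'])
model.load_state_dict(ckpt['model'])
print('restored epoch', ckpt['epoch'])
```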
train_nlvr.py ADDED
@@ -0,0 +1,213 @@
+ '''
+  * Copyright (c) 2022, salesforce.com, inc.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+  * By Junnan Li
+ '''
+ import argparse
+ import os
+ import ruamel_yaml as yaml
+ import numpy as np
+ import random
+ import time
+ import datetime
+ import json
+ from pathlib import Path
+ import json
+ import pickle
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+ import torch.backends.cudnn as cudnn
+ import torch.distributed as dist
+
+ from models.blip_nlvr import blip_nlvr
+
+ import utils
+ from utils import cosine_lr_schedule, warmup_lr_schedule
+ from data import create_dataset, create_sampler, create_loader
+
+ def train(model, data_loader, optimizer, epoch, device, config):
+     # train
+     model.train()
+
+     metric_logger = utils.MetricLogger(delimiter="  ")
+     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=50, fmt='{value:.6f}'))
+     metric_logger.add_meter('loss', utils.SmoothedValue(window_size=50, fmt='{value:.4f}'))
+
+     header = 'Train Epoch: [{}]'.format(epoch)
+     print_freq = 50
+     step_size = 10
+
+     for i,(image0, image1, text, targets) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+
+         images = torch.cat([image0, image1], dim=0)
+         images, targets = images.to(device), targets.to(device)
+
+         loss = model(images, text, targets=targets, train=True)
+
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+
+         metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+         metric_logger.update(loss=loss.item())
+
+     # gather the stats from all processes
+     metric_logger.synchronize_between_processes()
+     print("Averaged stats:", metric_logger.global_avg())
+     return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
+
+
+ @torch.no_grad()
+ def evaluate(model, data_loader, device, config):
+     # test
+     model.eval()
+
+     metric_logger = utils.MetricLogger(delimiter="  ")
+
+     header = 'Evaluation:'
+     print_freq = 50
+
+     for image0, image1, text, targets in metric_logger.log_every(data_loader, print_freq, header):
+         images = torch.cat([image0, image1], dim=0)
+         images, targets = images.to(device), targets.to(device)
+
+         prediction = model(images, text, targets=targets, train=False)
+
+         _, pred_class = prediction.max(1)
+         accuracy = (targets==pred_class).sum() / targets.size(0)
+
+         metric_logger.meters['acc'].update(accuracy.item(), n=image0.size(0))
+
+     # gather the stats from all processes
+     metric_logger.synchronize_between_processes()
+
+     print("Averaged stats:", metric_logger.global_avg())
+     return {k: "{:.4f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
+
+
+
+ def main(args, config):
+     utils.init_distributed_mode(args)
+
+     device = torch.device(args.device)
+
+     # fix the seed for reproducibility
+     seed = args.seed + utils.get_rank()
+     torch.manual_seed(seed)
+     np.random.seed(seed)
+     random.seed(seed)
+     cudnn.benchmark = True
+
+     #### Dataset ####
+     print("Creating dataset")
+     datasets = create_dataset('nlvr', config)
+
+     if args.distributed:
+         num_tasks = utils.get_world_size()
+         global_rank = utils.get_rank()
+         samplers = create_sampler(datasets, [True,False,False], num_tasks, global_rank)
+     else:
+         samplers = [None, None, None]
+
+     batch_size=[config['batch_size_train'],config['batch_size_test'],config['batch_size_test']]
+     train_loader, val_loader, test_loader = create_loader(datasets,samplers,batch_size=batch_size,
+                                                           num_workers=[4,4,4],is_trains=[True,False,False],
+                                                           collate_fns=[None,None,None])
+
+     #### Model ####
+     print("Creating model")
+     model = blip_nlvr(pretrained=config['pretrained'], image_size=config['image_size'],
+                       vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'])
+
+     model = model.to(device)
+
+     model_without_ddp = model
+     if args.distributed:
+         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+         model_without_ddp = model.module
+
+     optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
+
+     print("Start training")
+     start_time = time.time()
+     best = 0
+     best_epoch = 0
+
+     for epoch in range(0, config['max_epoch']):
+         if not args.evaluate:
+             if args.distributed:
+                 train_loader.sampler.set_epoch(epoch)
+
+             cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
+
+             train_stats = train(model, train_loader, optimizer, epoch, device, config)
+
+         val_stats = evaluate(model, val_loader, device, config)
+         test_stats = evaluate(model, test_loader, device, config)
+
+         if utils.is_main_process():
+             if args.evaluate:
+                 log_stats = {**{f'val_{k}': v for k, v in val_stats.items()},
+                              **{f'test_{k}': v for k, v in test_stats.items()},
+                             }
+                 with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
+                     f.write(json.dumps(log_stats) + "\n")
+
+             else:
+                 log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+                              **{f'val_{k}': v for k, v in val_stats.items()},
+                              **{f'test_{k}': v for k, v in test_stats.items()},
+                              'epoch': epoch,
+                             }
+
+                 if float(val_stats['acc'])>best:
+                     save_obj = {
+                         'model': model_without_ddp.state_dict(),
+                         'optimizer': optimizer.state_dict(),
+                         'config': config,
+                         'epoch': epoch,
+                     }
+                     torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
+                     best = float(val_stats['acc'])
+                     best_epoch = epoch
+
+                 with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
+                     f.write(json.dumps(log_stats) + "\n")
+         if args.evaluate:
+             break
+
+         dist.barrier()
+
+     if utils.is_main_process():
+         with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
+             f.write("best epoch: %d"%best_epoch)
+
+     total_time = time.time() - start_time
+     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+     print('Training time {}'.format(total_time_str))
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--config', default='./configs/nlvr.yaml')
+     parser.add_argument('--output_dir', default='output/NLVR')
+     parser.add_argument('--evaluate', action='store_true')
+     parser.add_argument('--device', default='cuda')
+     parser.add_argument('--seed', default=42, type=int)
+     parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+     parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+     parser.add_argument('--distributed', default=True, type=bool)
+     args = parser.parse_args()
+
+     config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+
+     Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
+     yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+     main(args, config)
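
NLVR2 pairs two images with each sentence, and both `train` and `evaluate` above fold the pair into the batch dimension before the forward pass. A toy, shape-only illustration of that batching and of the accuracy computation (no real model involved; the 384 image size is just an example value of `config['image_size']`):

```python
# Toy illustration of the NLVR2 batching used above: a batch of B image
# pairs becomes 2*B images via torch.cat along the batch dimension.
import torch

B, C, H, W = 4, 3, 384, 384
image0 = torch.randn(B, C, H, W)
image1 = torch.randn(B, C, H, W)
images = torch.cat([image0, image1], dim=0)  # -> [2*B, C, H, W], as in train()/evaluate()
print(images.shape)

# Accuracy mirrors evaluate(): argmax over the two classes, compared to targets.
prediction = torch.randn(B, 2)               # stand-in for the model's logits
targets = torch.randint(0, 2, (B,))
_, pred_class = prediction.max(1)
accuracy = (targets == pred_class).sum() / targets.size(0)
print(accuracy.item())
```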
train_retrieval.py ADDED
@@ -0,0 +1,345 @@
+ '''
+  * Copyright (c) 2022, salesforce.com, inc.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+  * By Junnan Li
+ '''
+ import argparse
+ import os
+ import ruamel_yaml as yaml
+ import numpy as np
+ import random
+ import time
+ import datetime
+ import json
+ from pathlib import Path
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.backends.cudnn as cudnn
+ import torch.distributed as dist
+ from torch.utils.data import DataLoader
+
+ from models.blip_retrieval import blip_retrieval
+ import utils
+ from utils import cosine_lr_schedule
+ from data import create_dataset, create_sampler, create_loader
+
+
+ def train(model, data_loader, optimizer, epoch, device, config):
+     # train
+     model.train()
+
+     metric_logger = utils.MetricLogger(delimiter="  ")
+     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+     metric_logger.add_meter('loss_itm', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
+     metric_logger.add_meter('loss_ita', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
+     header = 'Train Epoch: [{}]'.format(epoch)
+     print_freq = 50
+
+     for i,(image, caption, idx) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+         image = image.to(device,non_blocking=True)
+         idx = idx.to(device,non_blocking=True)
+
+         if epoch>0:
+             alpha = config['alpha']
+         else:
+             alpha = config['alpha']*min(1,i/len(data_loader))
+
+         loss_ita, loss_itm = model(image, caption, alpha=alpha, idx=idx)
+         loss = loss_ita + loss_itm
+
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+
+         metric_logger.update(loss_itm=loss_itm.item())
+         metric_logger.update(loss_ita=loss_ita.item())
+         metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+     # gather the stats from all processes
+     metric_logger.synchronize_between_processes()
+     print("Averaged stats:", metric_logger.global_avg())
+     return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
+
+
+ @torch.no_grad()
+ def evaluation(model, data_loader, device, config):
+     # test
+     model.eval()
+
+     metric_logger = utils.MetricLogger(delimiter="  ")
+     header = 'Evaluation:'
+
+     print('Computing features for evaluation...')
+     start_time = time.time()
+
+     texts = data_loader.dataset.text
+     num_text = len(texts)
+     text_bs = 256
+     text_ids = []
+     text_embeds = []
+     text_atts = []
+     for i in range(0, num_text, text_bs):
+         text = texts[i: min(num_text, i+text_bs)]
+         text_input = model.tokenizer(text, padding='max_length', truncation=True, max_length=35, return_tensors="pt").to(device)
+         text_output = model.text_encoder(text_input.input_ids, attention_mask = text_input.attention_mask, mode='text')
+         text_embed = F.normalize(model.text_proj(text_output.last_hidden_state[:,0,:]))
+         text_embeds.append(text_embed)
+         text_ids.append(text_input.input_ids)
+         text_atts.append(text_input.attention_mask)
+
+     text_embeds = torch.cat(text_embeds,dim=0)
+     text_ids = torch.cat(text_ids,dim=0)
+     text_atts = torch.cat(text_atts,dim=0)
+     text_ids[:,0] = model.tokenizer.enc_token_id
+
+     image_feats = []
+     image_embeds = []
+     for image, img_id in data_loader:
+         image = image.to(device)
+         image_feat = model.visual_encoder(image)
+         image_embed = model.vision_proj(image_feat[:,0,:])
+         image_embed = F.normalize(image_embed,dim=-1)
+
+         image_feats.append(image_feat.cpu())
+         image_embeds.append(image_embed)
+
+     image_feats = torch.cat(image_feats,dim=0)
+     image_embeds = torch.cat(image_embeds,dim=0)
+
+     sims_matrix = image_embeds @ text_embeds.t()
+     score_matrix_i2t = torch.full((len(data_loader.dataset.image),len(texts)),-100.0).to(device)
+
+     num_tasks = utils.get_world_size()
+     rank = utils.get_rank()
+     step = sims_matrix.size(0)//num_tasks + 1
+     start = rank*step
+     end = min(sims_matrix.size(0),start+step)
+
+     for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
+         topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
+
+         encoder_output = image_feats[start+i].repeat(config['k_test'],1,1).to(device)
+         encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device)
+         output = model.text_encoder(text_ids[topk_idx],
+                                     attention_mask = text_atts[topk_idx],
+                                     encoder_hidden_states = encoder_output,
+                                     encoder_attention_mask = encoder_att,
+                                     return_dict = True,
+                                    )
+         score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
+         score_matrix_i2t[start+i,topk_idx] = score + topk_sim
+
+     sims_matrix = sims_matrix.t()
+     score_matrix_t2i = torch.full((len(texts),len(data_loader.dataset.image)),-100.0).to(device)
+
+     step = sims_matrix.size(0)//num_tasks + 1
+     start = rank*step
+     end = min(sims_matrix.size(0),start+step)
+
+     for i,sims in enumerate(metric_logger.log_every(sims_matrix[start:end], 50, header)):
+
+         topk_sim, topk_idx = sims.topk(k=config['k_test'], dim=0)
+         encoder_output = image_feats[topk_idx].to(device)
+         encoder_att = torch.ones(encoder_output.size()[:-1],dtype=torch.long).to(device)
+         output = model.text_encoder(text_ids[start+i].repeat(config['k_test'],1),
+                                     attention_mask = text_atts[start+i].repeat(config['k_test'],1),
+                                     encoder_hidden_states = encoder_output,
+                                     encoder_attention_mask = encoder_att,
+                                     return_dict = True,
+                                    )
+         score = model.itm_head(output.last_hidden_state[:,0,:])[:,1]
+         score_matrix_t2i[start+i,topk_idx] = score + topk_sim
+
+     if args.distributed:
+         dist.barrier()
+         torch.distributed.all_reduce(score_matrix_i2t, op=torch.distributed.ReduceOp.SUM)
+         torch.distributed.all_reduce(score_matrix_t2i, op=torch.distributed.ReduceOp.SUM)
+
+     total_time = time.time() - start_time
+     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+     print('Evaluation time {}'.format(total_time_str))
+
+     return score_matrix_i2t.cpu().numpy(), score_matrix_t2i.cpu().numpy()
+
+
+
+ @torch.no_grad()
+ def itm_eval(scores_i2t, scores_t2i, txt2img, img2txt):
+
+     #Images->Text
+     ranks = np.zeros(scores_i2t.shape[0])
+     for index,score in enumerate(scores_i2t):
+         inds = np.argsort(score)[::-1]
+         # Score
+         rank = 1e20
+         for i in img2txt[index]:
+             tmp = np.where(inds == i)[0][0]
+             if tmp < rank:
+                 rank = tmp
+         ranks[index] = rank
+
+     # Compute metrics
+     tr1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
+     tr5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
+     tr10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
+
+     #Text->Images
+     ranks = np.zeros(scores_t2i.shape[0])
+
+     for index,score in enumerate(scores_t2i):
+         inds = np.argsort(score)[::-1]
+         ranks[index] = np.where(inds == txt2img[index])[0][0]
+
+     # Compute metrics
+     ir1 = 100.0 * len(np.where(ranks < 1)[0]) / len(ranks)
+     ir5 = 100.0 * len(np.where(ranks < 5)[0]) / len(ranks)
+     ir10 = 100.0 * len(np.where(ranks < 10)[0]) / len(ranks)
+
+     tr_mean = (tr1 + tr5 + tr10) / 3
+     ir_mean = (ir1 + ir5 + ir10) / 3
+     r_mean = (tr_mean + ir_mean) / 2
+
+     eval_result = {'txt_r1': tr1,
+                    'txt_r5': tr5,
+                    'txt_r10': tr10,
+                    'txt_r_mean': tr_mean,
+                    'img_r1': ir1,
+                    'img_r5': ir5,
+                    'img_r10': ir10,
+                    'img_r_mean': ir_mean,
+                    'r_mean': r_mean}
+     return eval_result
+
+
+ def main(args, config):
+     utils.init_distributed_mode(args)
+
+     device = torch.device(args.device)
+
+     # fix the seed for reproducibility
+     seed = args.seed + utils.get_rank()
+     torch.manual_seed(seed)
+     np.random.seed(seed)
+     random.seed(seed)
+     cudnn.benchmark = True
+
+     #### Dataset ####
+     print("Creating retrieval dataset")
+     train_dataset, val_dataset, test_dataset = create_dataset('retrieval_%s'%config['dataset'], config)
+
+     if args.distributed:
+         num_tasks = utils.get_world_size()
+         global_rank = utils.get_rank()
+         samplers = create_sampler([train_dataset], [True], num_tasks, global_rank) + [None, None]
+     else:
+         samplers = [None, None, None]
+
+     train_loader, val_loader, test_loader = create_loader([train_dataset, val_dataset, test_dataset],samplers,
+                                                           batch_size=[config['batch_size_train']]+[config['batch_size_test']]*2,
+                                                           num_workers=[4,4,4],
+                                                           is_trains=[True, False, False],
+                                                           collate_fns=[None,None,None])
+
+
+     #### Model ####
+     print("Creating model")
+     model = blip_retrieval(pretrained=config['pretrained'], image_size=config['image_size'], vit=config['vit'],
+                            vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'],
+                            queue_size=config['queue_size'], negative_all_rank=config['negative_all_rank'])
+
+     model = model.to(device)
+
+     model_without_ddp = model
+     if args.distributed:
+         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+         model_without_ddp = model.module
+
+     optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
+
+     best = 0
+     best_epoch = 0
+
+     print("Start training")
+     start_time = time.time()
+
+     for epoch in range(0, config['max_epoch']):
+         if not args.evaluate:
+             if args.distributed:
+                 train_loader.sampler.set_epoch(epoch)
+
+             cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
+
+             train_stats = train(model, train_loader, optimizer, epoch, device, config)
+
+         score_val_i2t, score_val_t2i = evaluation(model_without_ddp, val_loader, device, config)
+         score_test_i2t, score_test_t2i = evaluation(model_without_ddp, test_loader, device, config)
+
+         if utils.is_main_process():
+
+             val_result = itm_eval(score_val_i2t, score_val_t2i, val_loader.dataset.txt2img, val_loader.dataset.img2txt)
+             print(val_result)
+
+             if val_result['r_mean']>best:
+                 save_obj = {
+                     'model': model_without_ddp.state_dict(),
+                     'optimizer': optimizer.state_dict(),
+                     'config': config,
+                     'epoch': epoch,
+                 }
+                 torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_best.pth'))
+                 best = val_result['r_mean']
+                 best_epoch = epoch
+
+             test_result = itm_eval(score_test_i2t, score_test_t2i, test_loader.dataset.txt2img, test_loader.dataset.img2txt)
+             print(test_result)
+
+             if args.evaluate:
+                 log_stats = {**{f'val_{k}': v for k, v in val_result.items()},
+                              **{f'test_{k}': v for k, v in test_result.items()},
+                             }
+                 with open(os.path.join(args.output_dir, "evaluate.txt"),"a") as f:
+                     f.write(json.dumps(log_stats) + "\n")
+             else:
+                 log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+                              **{f'val_{k}': v for k, v in val_result.items()},
+                              **{f'test_{k}': v for k, v in test_result.items()},
+                              'epoch': epoch,
+                              'best_epoch': best_epoch,
+                             }
+                 with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
+                     f.write(json.dumps(log_stats) + "\n")
+
+         if args.evaluate:
+             break
+
+         dist.barrier()
+         torch.cuda.empty_cache()
+
+     total_time = time.time() - start_time
+     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+     print('Training time {}'.format(total_time_str))
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--config', default='./configs/retrieval_flickr.yaml')
+     parser.add_argument('--output_dir', default='output/Retrieval_flickr')
+     parser.add_argument('--evaluate', action='store_true')
+     parser.add_argument('--device', default='cuda')
+     parser.add_argument('--seed', default=42, type=int)
+     parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+     parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+     parser.add_argument('--distributed', default=True, type=bool)
+     args = parser.parse_args()
+
+     config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+
+     Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
+     yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+     main(args, config)
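
`itm_eval` turns the two score matrices into recall@1/5/10 for both directions plus their means. A tiny worked example, assuming `itm_eval` above is in scope; the 2-image, 4-caption mappings are purely illustrative:

```python
# Worked toy example for itm_eval() above: 2 images, 4 captions,
# 2 ground-truth captions per image. Assumes itm_eval is in scope.
import numpy as np

scores_i2t = np.array([[0.9, 0.8, 0.1, 0.2],   # image 0 scores for texts 0..3
                       [0.1, 0.2, 0.7, 0.9]])  # image 1
scores_t2i = scores_i2t.T                      # symmetric toy scores
img2txt = {0: [0, 1], 1: [2, 3]}               # ground-truth texts per image
txt2img = {0: 0, 1: 0, 2: 1, 3: 1}             # ground-truth image per text

print(itm_eval(scores_i2t, scores_t2i, txt2img, img2txt))
# Every ground truth ranks first here, so all r1/r5/r10 values come out 100.0.
```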
train_vqa.py ADDED
@@ -0,0 +1,202 @@
+ '''
+  * Copyright (c) 2022, salesforce.com, inc.
+  * All rights reserved.
+  * SPDX-License-Identifier: BSD-3-Clause
+  * For full license text, see LICENSE.txt file in the repo root or https://opensource.org/licenses/BSD-3-Clause
+  * By Junnan Li
+ '''
+ import argparse
+ import os
+ import ruamel_yaml as yaml
+ import numpy as np
+ import random
+ import time
+ import datetime
+ import json
+ from pathlib import Path
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader
+ import torch.backends.cudnn as cudnn
+ import torch.distributed as dist
+
+ from models.blip_vqa import blip_vqa
+ import utils
+ from utils import cosine_lr_schedule
+ from data import create_dataset, create_sampler, create_loader
+ from data.vqa_dataset import vqa_collate_fn
+ from data.utils import save_result
+
+
+ def train(model, data_loader, optimizer, epoch, device):
+     # train
+     model.train()
+
+     metric_logger = utils.MetricLogger(delimiter="  ")
+     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
+     metric_logger.add_meter('loss', utils.SmoothedValue(window_size=1, fmt='{value:.4f}'))
+
+     header = 'Train Epoch: [{}]'.format(epoch)
+     print_freq = 50
+
+     for i,(image, question, answer, weights, n) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+         image, weights = image.to(device,non_blocking=True), weights.to(device,non_blocking=True)
+
+         loss = model(image, question, answer, train=True, n=n, weights=weights)
+
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+
+         metric_logger.update(loss=loss.item())
+         metric_logger.update(lr=optimizer.param_groups[0]["lr"])
+
+     # gather the stats from all processes
+     metric_logger.synchronize_between_processes()
+     print("Averaged stats:", metric_logger.global_avg())
+     return {k: "{:.3f}".format(meter.global_avg) for k, meter in metric_logger.meters.items()}
+
+
+ @torch.no_grad()
+ def evaluation(model, data_loader, device, config):
+     # test
+     model.eval()
+
+     metric_logger = utils.MetricLogger(delimiter="  ")
+     header = 'Generate VQA test result:'
+     print_freq = 50
+
+     result = []
+
+     if config['inference']=='rank':
+         answer_list = data_loader.dataset.answer_list
+         answer_candidates = model.tokenizer(answer_list, padding='longest', return_tensors='pt').to(device)
+         answer_candidates.input_ids[:,0] = model.tokenizer.bos_token_id
+
+     for n, (image, question, question_id) in enumerate(metric_logger.log_every(data_loader, print_freq, header)):
+         image = image.to(device,non_blocking=True)
+
+         if config['inference']=='generate':
+             answers = model(image, question, train=False, inference='generate')
+
+             for answer, ques_id in zip(answers, question_id):
+                 ques_id = int(ques_id.item())
+                 result.append({"question_id":ques_id, "answer":answer})
+
+         elif config['inference']=='rank':
+             answer_ids = model(image, question, answer_candidates, train=False, inference='rank', k_test=config['k_test'])
+
+             for ques_id, answer_id in zip(question_id, answer_ids):
+                 result.append({"question_id":int(ques_id.item()), "answer":answer_list[answer_id]})
+
+     return result
+
+
+ def main(args, config):
+     utils.init_distributed_mode(args)
+
+     device = torch.device(args.device)
+
+     # fix the seed for reproducibility
+     seed = args.seed + utils.get_rank()
+     torch.manual_seed(seed)
+     np.random.seed(seed)
+     random.seed(seed)
+     cudnn.benchmark = True
+
+     #### Dataset ####
+     print("Creating vqa datasets")
+     datasets = create_dataset('vqa', config)
+
+     if args.distributed:
+         num_tasks = utils.get_world_size()
+         global_rank = utils.get_rank()
+         samplers = create_sampler(datasets, [True, False], num_tasks, global_rank)
+     else:
+         samplers = [None, None]
+
+     train_loader, test_loader = create_loader(datasets,samplers,
+                                               batch_size=[config['batch_size_train'],config['batch_size_test']],
+                                               num_workers=[4,4],is_trains=[True, False],
+                                               collate_fns=[vqa_collate_fn,None])
+     #### Model ####
+     print("Creating model")
+     model = blip_vqa(pretrained=config['pretrained'], image_size=config['image_size'],
+                      vit=config['vit'], vit_grad_ckpt=config['vit_grad_ckpt'], vit_ckpt_layer=config['vit_ckpt_layer'])
+
+     model = model.to(device)
+
+     model_without_ddp = model
+     if args.distributed:
+         model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
+         model_without_ddp = model.module
+
+     optimizer = torch.optim.AdamW(params=model.parameters(), lr=config['init_lr'], weight_decay=config['weight_decay'])
+
+     best = 0
+     best_epoch = 0
+
+     print("Start training")
+     start_time = time.time()
+     for epoch in range(0, config['max_epoch']):
+         if not args.evaluate:
+             if args.distributed:
+                 train_loader.sampler.set_epoch(epoch)
+
+             cosine_lr_schedule(optimizer, epoch, config['max_epoch'], config['init_lr'], config['min_lr'])
+
+             train_stats = train(model, train_loader, optimizer, epoch, device)
+
+         else:
+             break
+
+         if utils.is_main_process():
+             log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
+                          'epoch': epoch,
+                         }
+             with open(os.path.join(args.output_dir, "log.txt"),"a") as f:
+                 f.write(json.dumps(log_stats) + "\n")
+
+             save_obj = {
+                 'model': model_without_ddp.state_dict(),
+                 'optimizer': optimizer.state_dict(),
+                 'config': config,
+                 'epoch': epoch,
+             }
+             torch.save(save_obj, os.path.join(args.output_dir, 'checkpoint_%02d.pth'%epoch))
+
+         dist.barrier()
+
+     vqa_result = evaluation(model_without_ddp, test_loader, device, config)
+     result_file = save_result(vqa_result, args.result_dir, 'vqa_result')
+
+     total_time = time.time() - start_time
+     total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+     print('Training time {}'.format(total_time_str))
+
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--config', default='./configs/vqa.yaml')
+     parser.add_argument('--output_dir', default='output/VQA')
+     parser.add_argument('--evaluate', action='store_true')
+     parser.add_argument('--device', default='cuda')
+     parser.add_argument('--seed', default=42, type=int)
+     parser.add_argument('--world_size', default=1, type=int, help='number of distributed processes')
+     parser.add_argument('--dist_url', default='env://', help='url used to set up distributed training')
+     parser.add_argument('--distributed', default=True, type=bool)
+     args = parser.parse_args()
+
+     config = yaml.load(open(args.config, 'r'), Loader=yaml.Loader)
+
+     args.result_dir = os.path.join(args.output_dir, 'result')
+
+     Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+     Path(args.result_dir).mkdir(parents=True, exist_ok=True)
+
+     yaml.dump(config, open(os.path.join(args.output_dir, 'config.yaml'), 'w'))
+
+     main(args, config)
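
Whichever inference mode is configured ('generate' or 'rank'), `evaluation` produces a list of `{"question_id", "answer"}` records, which `save_result` writes to disk. A hedged sketch for summarizing that output, assuming the file landed at `output/VQA/result/vqa_result.json` (the exact name comes from `data.utils.save_result`, which is outside this commit):

```python
# Hedged sketch: summarize the answers train_vqa.py wrote via save_result().
# The path is an assumption -- adjust it to the file actually produced.
import json
from collections import Counter

with open('output/VQA/result/vqa_result.json') as f:
    results = json.load(f)      # entries look like {"question_id": ..., "answer": ...}

print(len(results), 'answers')
print(Counter(r['answer'] for r in results).most_common(5))
```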
utils.py ADDED
@@ -0,0 +1,278 @@
+ import math
+ def cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr):
+     """Decay the learning rate"""
+     lr = (init_lr - min_lr) * 0.5 * (1. + math.cos(math.pi * epoch / max_epoch)) + min_lr
+     for param_group in optimizer.param_groups:
+         param_group['lr'] = lr
+
+ def warmup_lr_schedule(optimizer, step, max_step, init_lr, max_lr):
+     """Warmup the learning rate"""
+     lr = min(max_lr, init_lr + (max_lr - init_lr) * step / max_step)
+     for param_group in optimizer.param_groups:
+         param_group['lr'] = lr
+
+ def step_lr_schedule(optimizer, epoch, init_lr, min_lr, decay_rate):
+     """Decay the learning rate"""
+     lr = max(min_lr, init_lr * (decay_rate**epoch))
+     for param_group in optimizer.param_groups:
+         param_group['lr'] = lr
+
+ import numpy as np
+ import io
+ import os
+ import time
+ from collections import defaultdict, deque
+ import datetime
+
+ import torch
+ import torch.distributed as dist
+
+ class SmoothedValue(object):
+     """Track a series of values and provide access to smoothed values over a
+     window or the global series average.
+     """
+
+     def __init__(self, window_size=20, fmt=None):
+         if fmt is None:
+             fmt = "{median:.4f} ({global_avg:.4f})"
+         self.deque = deque(maxlen=window_size)
+         self.total = 0.0
+         self.count = 0
+         self.fmt = fmt
+
+     def update(self, value, n=1):
+         self.deque.append(value)
+         self.count += n
+         self.total += value * n
+
+     def synchronize_between_processes(self):
+         """
+         Warning: does not synchronize the deque!
+         """
+         if not is_dist_avail_and_initialized():
+             return
+         t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+         dist.barrier()
+         dist.all_reduce(t)
+         t = t.tolist()
+         self.count = int(t[0])
+         self.total = t[1]
+
+     @property
+     def median(self):
+         d = torch.tensor(list(self.deque))
+         return d.median().item()
+
+     @property
+     def avg(self):
+         d = torch.tensor(list(self.deque), dtype=torch.float32)
+         return d.mean().item()
+
+     @property
+     def global_avg(self):
+         return self.total / self.count
+
+     @property
+     def max(self):
+         return max(self.deque)
+
+     @property
+     def value(self):
+         return self.deque[-1]
+
+     def __str__(self):
+         return self.fmt.format(
+             median=self.median,
+             avg=self.avg,
+             global_avg=self.global_avg,
+             max=self.max,
+             value=self.value)
+
+
+ class MetricLogger(object):
+     def __init__(self, delimiter="\t"):
+         self.meters = defaultdict(SmoothedValue)
+         self.delimiter = delimiter
+
+     def update(self, **kwargs):
+         for k, v in kwargs.items():
+             if isinstance(v, torch.Tensor):
+                 v = v.item()
+             assert isinstance(v, (float, int))
+             self.meters[k].update(v)
+
+     def __getattr__(self, attr):
+         if attr in self.meters:
+             return self.meters[attr]
+         if attr in self.__dict__:
+             return self.__dict__[attr]
+         raise AttributeError("'{}' object has no attribute '{}'".format(
+             type(self).__name__, attr))
+
+     def __str__(self):
+         loss_str = []
+         for name, meter in self.meters.items():
+             loss_str.append(
+                 "{}: {}".format(name, str(meter))
+             )
+         return self.delimiter.join(loss_str)
+
+     def global_avg(self):
+         loss_str = []
+         for name, meter in self.meters.items():
+             loss_str.append(
+                 "{}: {:.4f}".format(name, meter.global_avg)
+             )
+         return self.delimiter.join(loss_str)
+
+     def synchronize_between_processes(self):
+         for meter in self.meters.values():
+             meter.synchronize_between_processes()
+
+     def add_meter(self, name, meter):
+         self.meters[name] = meter
+
+     def log_every(self, iterable, print_freq, header=None):
+         i = 0
+         if not header:
+             header = ''
+         start_time = time.time()
+         end = time.time()
+         iter_time = SmoothedValue(fmt='{avg:.4f}')
+         data_time = SmoothedValue(fmt='{avg:.4f}')
+         space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+         log_msg = [
+             header,
+             '[{0' + space_fmt + '}/{1}]',
+             'eta: {eta}',
+             '{meters}',
+             'time: {time}',
+             'data: {data}'
+         ]
+         if torch.cuda.is_available():
+             log_msg.append('max mem: {memory:.0f}')
+         log_msg = self.delimiter.join(log_msg)
+         MB = 1024.0 * 1024.0
+         for obj in iterable:
+             data_time.update(time.time() - end)
+             yield obj
+             iter_time.update(time.time() - end)
+             if i % print_freq == 0 or i == len(iterable) - 1:
+                 eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                 eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                 if torch.cuda.is_available():
+                     print(log_msg.format(
+                         i, len(iterable), eta=eta_string,
+                         meters=str(self),
+                         time=str(iter_time), data=str(data_time),
+                         memory=torch.cuda.max_memory_allocated() / MB))
+                 else:
+                     print(log_msg.format(
+                         i, len(iterable), eta=eta_string,
+                         meters=str(self),
+                         time=str(iter_time), data=str(data_time)))
+             i += 1
+             end = time.time()
+         total_time = time.time() - start_time
+         total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+         print('{} Total time: {} ({:.4f} s / it)'.format(
+             header, total_time_str, total_time / len(iterable)))
+
+
+ class AttrDict(dict):
+     def __init__(self, *args, **kwargs):
+         super(AttrDict, self).__init__(*args, **kwargs)
+         self.__dict__ = self
+
+
+ def compute_acc(logits, label, reduction='mean'):
+     ret = (torch.argmax(logits, dim=1) == label).float()
+     if reduction == 'none':
+         return ret.detach()
+     elif reduction == 'mean':
+         return ret.mean().item()
+
+ def compute_n_params(model, return_str=True):
+     tot = 0
+     for p in model.parameters():
+         w = 1
+         for x in p.shape:
+             w *= x
+         tot += w
+     if return_str:
+         if tot >= 1e6:
+             return '{:.1f}M'.format(tot / 1e6)
+         else:
+             return '{:.1f}K'.format(tot / 1e3)
+     else:
+         return tot
+
+ def setup_for_distributed(is_master):
+     """
+     This function disables printing when not in master process
+     """
+     import builtins as __builtin__
+     builtin_print = __builtin__.print
+
+     def print(*args, **kwargs):
+         force = kwargs.pop('force', False)
+         if is_master or force:
+             builtin_print(*args, **kwargs)
+
+     __builtin__.print = print
+
+
+ def is_dist_avail_and_initialized():
+     if not dist.is_available():
+         return False
+     if not dist.is_initialized():
+         return False
+     return True
+
+
+ def get_world_size():
+     if not is_dist_avail_and_initialized():
+         return 1
+     return dist.get_world_size()
+
+
+ def get_rank():
+     if not is_dist_avail_and_initialized():
+         return 0
+     return dist.get_rank()
+
+
+ def is_main_process():
+     return get_rank() == 0
+
+
+ def save_on_master(*args, **kwargs):
+     if is_main_process():
+         torch.save(*args, **kwargs)
+
+
+ def init_distributed_mode(args):
+     if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+         args.rank = int(os.environ["RANK"])
+         args.world_size = int(os.environ['WORLD_SIZE'])
+         args.gpu = int(os.environ['LOCAL_RANK'])
+     elif 'SLURM_PROCID' in os.environ:
+         args.rank = int(os.environ['SLURM_PROCID'])
+         args.gpu = args.rank % torch.cuda.device_count()
+     else:
+         print('Not using distributed mode')
+         args.distributed = False
+         return
+
+     args.distributed = True
+
+     torch.cuda.set_device(args.gpu)
+     args.dist_backend = 'nccl'
+     print('| distributed init (rank {}, world {}): {}'.format(
+         args.rank, args.world_size, args.dist_url), flush=True)
+     torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
+                                          world_size=args.world_size, rank=args.rank)
+     torch.distributed.barrier()
+     setup_for_distributed(args.rank == 0)
+
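
All the training scripts above call `cosine_lr_schedule` once per epoch: it interpolates from `init_lr` down to `min_lr` along half a cosine period over `max_epoch` epochs. A quick self-contained demo of the resulting curve, assuming it runs from the repo root so `utils` is importable:

```python
# Demo of cosine_lr_schedule(): half-cosine decay from init_lr to min_lr.
import torch
from utils import cosine_lr_schedule  # the function defined above

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-5)

max_epoch, init_lr, min_lr = 10, 1e-5, 0.0
for epoch in range(max_epoch):
    cosine_lr_schedule(optimizer, epoch, max_epoch, init_lr, min_lr)
    print(epoch, optimizer.param_groups[0]['lr'])
# epoch 0 prints init_lr; later epochs decay smoothly toward min_lr
```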
walk.jpeg ADDED