xmutly committed on
Commit 317bfc1 · verified · 1 Parent(s): 5d3e6fb

Upload 13 files

ckps/model_slots_step_300000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:22dbaabe6af89f9d67f8127c856c0cbf95a0ef2447b499d7eacefc3e258b29e1
+ size 54346897
output_slots/none.txt ADDED
File without changes
train/adversarial_training_clip_with_object_token.py ADDED
@@ -0,0 +1,599 @@
1
+ import sys
2
+
3
+ from train.datasets import COCOFlickrDataset, ImageNetDataset
4
+ from CLIP_eval.eval_utils import load_clip_model
5
+
6
+ sys.path.append("open_flamingo")
7
+ import os
8
+ import shutil
9
+ import time
10
+ import string
11
+ import random
12
+
13
+ import numpy as np
14
+ import open_clip
15
+ import torch
16
+ import torch.nn.functional as F
17
+ from torch.utils.data import DataLoader
18
+ from training.scheduler import cosine_lr
19
+ from torchvision import transforms
20
+ from open_flamingo.eval.classification_utils import IMAGENET_1K_CLASS_ID_TO_LABEL
21
+ from train.pgd_train import pgd
22
+ from train.apgd_train import apgd_train as apgd
23
+ import wandb
24
+ from train.utils import init_wandb, AverageMeter
25
+ from train.sam_data import SamData
26
+ from open_flamingo.eval.models.utils import unwrap_model
27
+ from train.utils import str2bool
28
+
29
+ import argparse
30
+
31
+ from slots.DINOSAUR import DINOSAURpp
32
+ import matplotlib.pyplot as plt
33
+ from einops import rearrange, repeat
34
+
35
+ parser = argparse.ArgumentParser()
36
+ parser.add_argument('--clip_model_name', type=str, default='ViT-L-14', help='ViT-L-14, ViT-B-32')
37
+ parser.add_argument('--pretrained', type=str, default='openai')
38
+ parser.add_argument('--dataset', type=str, default='imagenet')
39
+ parser.add_argument('--template', type=str, default='std')
40
+ parser.add_argument('--imagenet_root', type=str, default='/mnt/datasets/imagenet', help='Imagenet dataset root directory')
41
+ parser.add_argument('--output_normalize', type=str2bool, default=False, help='Whether the embedding is normalized')
42
+ parser.add_argument('--start_step', type=int, default=0, help='Start step for training')
43
+ parser.add_argument('--optimizer_state', type=str, default='', help='Optimizer state file path')
44
+ parser.add_argument('--steps', type=int, default=20000, help='Number of training steps')
45
+ parser.add_argument('--warmup', type=int, default=14000, help='Warmup steps')
46
+ parser.add_argument('--batch_size', type=int, default=256)
47
+ parser.add_argument('--loss', type=str, default='l2', help='ce, l2')
48
+ parser.add_argument('--loss_clean', type=str, default='none', help='ce, l2')
49
+ parser.add_argument('--clean_weight', type=float, default=0., help='Weight for clean loss')
50
+ parser.add_argument('--trades', type=str2bool, default=False, help='Use TRADES')
51
+ parser.add_argument('--opt', type=str, default='adamw', help='Optimizer type; sgd, adamw')
52
+ parser.add_argument('--momentum_sgd', type=float, default=0.9, help='Momentum for SGD optimizer')
53
+ parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate')
54
+ parser.add_argument('--wd', type=float, default=1e-4, help='Weight decay')
55
+ parser.add_argument('--attack', type=str, default='apgd', help='Adversarial attack type')
56
+ parser.add_argument('--inner_loss', type=str, default='l2', help='Inner loss function for adversarial training')
57
+ parser.add_argument('--norm', type=str, default='linf', help='Norm for adversarial perturbation')
58
+ parser.add_argument('--eps', type=float, default=4, help='Epsilon for adversarial perturbation')
59
+ parser.add_argument('--iterations_adv', type=int, default=10, help='Iterations for adversarial attack')
60
+ parser.add_argument('--stepsize_adv', type=float, default=1., help='Step size for adversarial attack (no effect for apgd)')
61
+ parser.add_argument('--wandb', type=str2bool, default=True, help='Use Weights & Biases for logging')
62
+ parser.add_argument('--experiment_name', type=str, default='')
63
+ parser.add_argument('--overwrite', type=str2bool, default=False, help='Overwrite existing directory')
64
+ parser.add_argument('--log_freq', type=int, default=1, help='Logging frequency')
65
+ parser.add_argument('--eval_freq', type=int, default=50, help='Evaluation frequency')
66
+ parser.add_argument('--output_dir', type=str, default='', help='Output directory')
67
+ parser.add_argument('--save_checkpoints', type=str2bool, default=True, help='Save 10 training checkpoints')
68
+ parser.add_argument('--devices', type=str, default='', help='Device IDs for CUDA')
69
+
70
+
71
+ ######################################### For object-centric relation reasoning add ###########################
72
+ parser.add_argument('--slots_ckp', type=str, default='/home/tly/RobustVLM/output_slots/ViT-L-14_openai_imagenet_l2_imagenet_SLOTS_NbrnT/checkpoints/fallback_390200.pt', help='slots model ckp root directory')
73
+
74
+
75
+ def main(args):
76
+ # setup wandb
77
+ if args.wandb:
78
+ init_wandb(
79
+ project_name='clip-finetune',
80
+ model_name=args.finetuned_model_name,
81
+ config=vars(args)
82
+ )
83
+ else:
84
+ wandb.init(mode='disabled')
85
+
86
+ # print args
87
+ print(f"Arguments:\n{'-' * 20}")
88
+ for arg, value in vars(args).items():
89
+ print(f"{arg}: {value}")
90
+ print(f"{'-' * 20}")
91
+
92
+ # setup dirs
93
+ if args.overwrite:
94
+ shutil.rmtree(args.output_dir, ignore_errors=True)
95
+ os.makedirs(os.path.join(args.output_dir, 'checkpoints'), exist_ok=False)
96
+
97
+ # write args to file
98
+ with open(os.path.join(args.output_dir, 'args.txt'), 'w') as f:
99
+ f.write(str(args))
100
+
101
+ main_device = 0
102
+ # get models
103
+ model_orig, _, image_processor = open_clip.create_model_and_transforms(
104
+ args.clip_model_name, pretrained='openai'  # optionally pass output_tokens=True to return the CLS token plus patch tokens
105
+ )
106
+ if args.optimizer_state != '':
107
+ assert args.start_step > 0
108
+ assert str(args.start_step) in args.optimizer_state
109
+ assert args.pretrained in ['', 'none']
110
+ args.pretrained = args.optimizer_state.replace('_opt', '')
111
+ model, _, _ = load_clip_model(args.clip_model_name, args.pretrained)
112
+
113
+ # Remove the Normalize transform by creating a new Compose object
114
+ preprocessor_without_normalize = transforms.Compose(image_processor.transforms[:-1])
115
+ normalize = image_processor.transforms[-1]
116
+ del image_processor
117
+ print(f'[preprocessor_without_normalize] {preprocessor_without_normalize}')
118
+ print(f'[normalize] {normalize}')
119
+ # preprocessor_without_normalize contains following transforms:
120
+ # - Resize(size=224, interpolation=bicubic, max_size=None, antialias=warn)
121
+ # - CenterCrop(size=(224, 224))
122
+ # - ToTensor()
123
+ # normalize:
124
+ # Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
125
+
126
+
127
+ ####################################################### get slot-attention model #########################################################
128
+ cfg_dict = {'slot_dim': 256, 'num_slots': 10, 'token_num': 256, 'ISA': False, 'slot_att_iter': 3, 'query_opt': False}
129
+ model_slots = DINOSAURpp(cfg_dict)
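+ # DINOSAUR++ slot-attention module: groups the 256 CLIP patch tokens into 10 object slots of dim 256 (per cfg_dict above)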
130
+ proj_head = torch.nn.Linear(256, 1024)  # project slot_dim (256) to the ViT-L/14 token width (1024)
131
+
132
+
133
+ # get data
134
+ if args.dataset == 'imagenet':
135
+ dataset = ImageNetDataset(
136
+ root=args.imagenet_root + '/train',
137
+ transform=preprocessor_without_normalize,
138
+ )
139
+
140
+ elif args.dataset == 'segment_anything':
141
+ dataset = SamData('/data/naman_deep_singh/datasets/newSAM', transform=preprocessor_without_normalize)
142
+
143
+ print(dataset.__len__())
144
+ elif args.dataset == 'coco':
145
+ if os.path.exists('/mnt/datasets/coco'):
146
+ image_dir_path = '/mnt/datasets/coco/train2017'
147
+ annotations_path = '/mnt/datasets/coco/annotations/captions_train2017.json'
148
+ elif os.path.exists('/mnt/lustre'):
149
+ image_dir_path = '/mnt/lustre/hein/cschlarmann37/datasets/coco/train2017'
150
+ annotations_path = '/mnt/lustre/hein/cschlarmann37/datasets/coco/annotations/captions_train2017.json'
151
+ else:
152
+ raise ValueError('COCO dataset not found')
153
+ dataset = COCOFlickrDataset(
154
+ image_dir_path=image_dir_path,
155
+ annotations_path=annotations_path,
156
+ transform=preprocessor_without_normalize
157
+ )
158
+ dataset_eval = ImageNetDataset(
159
+ root=args.imagenet_root + '/val',
160
+ transform=preprocessor_without_normalize,
161
+ )
162
+ dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=8, drop_last=True)
163
+ dataloader_eval = DataLoader(dataset_eval, batch_size=args.batch_size, shuffle=True, num_workers=8, drop_last=True)
164
+
165
+ # Get text label embeddings of all ImageNet classes
166
+ if args.template == 'std':
167
+ template = 'This is a photo of a {}'
168
+ elif args.template == 'blurry':
169
+ template = 'This is a blurry photo of a {}'
170
+ else:
171
+ raise ValueError(f'Unknown template: {args.template}')
172
+ print(f'template: {template}')
173
+ texts = [template.format(c) for c in IMAGENET_1K_CLASS_ID_TO_LABEL.values()]
174
+ text_tokens = open_clip.tokenize(texts)
175
+ model_orig.to(main_device)
176
+ with torch.no_grad():
177
+ embedding_text_labels_norm = []
178
+ for el in (text_tokens[:500], text_tokens[500:]):
179
+ # we need to split the text tokens into two batches because otherwise we run out of memory
180
+ # note that we are accessing the model directly here, not the CustomModel wrapper
181
+ # thus it's always normalizing the text embeddings
182
+ embedding_text_labels_norm.append(
183
+ model_orig.encode_text(el.to(main_device), normalize=True).detach().cpu()
184
+ )
185
+ embedding_text_labels_norm = torch.cat(embedding_text_labels_norm).T.to(main_device)
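+ # resulting shape is (embed_dim, 1000): each column is the normalized text embedding of one ImageNet class (checked by the asserts below)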
186
+ assert torch.allclose(
187
+ F.normalize(embedding_text_labels_norm, dim=0),
188
+ embedding_text_labels_norm
189
+ )
190
+ if args.clip_model_name == 'ViT-B-32':
191
+ assert embedding_text_labels_norm.shape == (512, 1000), embedding_text_labels_norm.shape
192
+ elif args.clip_model_name in ('ViT-L-14', 'ViT-L-14-336'):
193
+ assert embedding_text_labels_norm.shape == (768, 1000), embedding_text_labels_norm.shape
194
+ else:
195
+ raise ValueError(f'Unknown model: {args.clip_model_name}')
196
+
197
+ model_orig.cpu()
198
+ model_orig = ClipVisionModel(model=model_orig.visual, args=args, normalize=normalize)
199
+ if num_gpus > 1:
200
+ model_orig = torch.nn.DataParallel(model_orig)
201
+ model_orig.cuda()
202
+
203
+ model = ClipVisionModel(model=model.visual, args=args, normalize=normalize)
204
+ if num_gpus > 1:
205
+ model = torch.nn.DataParallel(model)
206
+ model.cuda()
207
+
208
+ ####################################################### get slot-attention model #########################################################
209
+ model_slots = model_slots
210
+ if num_gpus > 1:
211
+ model_slots = torch.nn.DataParallel(model_slots)
212
+ proj_head = torch.nn.DataParallel(proj_head)
213
+ model_slots.cuda()
214
+ proj_head.cuda()
215
+
216
+ # set optimizer (all params have requires_grad=True)
217
+ params = unwrap_model(model).model.parameters()
218
+ params_head = unwrap_model(proj_head).parameters()
219
+
220
+ if args.opt == 'adamw':
221
+ optimizer = torch.optim.AdamW(
222
+ [{'params': params},
223
+ {'params': params_head}], lr=args.lr, weight_decay=args.wd)
224
+ elif args.opt == 'sgd':
225
+ optimizer = torch.optim.SGD(
226
+ params,
227
+ lr=args.lr,
228
+ momentum=args.momentum_sgd,
229
+ weight_decay=args.wd
230
+ )
231
+ else:
232
+ raise ValueError(f'Optimizer {args.opt} not supported.')
233
+ if args.optimizer_state != '':
234
+ optimizer.load_state_dict(torch.load(args.optimizer_state))
235
+
236
+ # set scheduler
237
+ scheduler = cosine_lr(optimizer, args.lr, args.warmup, args.steps)
238
+
239
+ # compute amount of epochs
240
+ total_epochs = args.steps / len(dataloader)
241
+ print(f'train for {total_epochs} epochs')
242
+ args.total_epochs = total_epochs
243
+
244
+ # finetune
245
+ step_total = args.start_step
246
+ epoch = 0
247
+ while step_total < args.steps:
248
+ step_total = train_one_epoch(
249
+ step_total,
250
+ model=model,
251
+ model_orig=model_orig,
252
+ model_slots=model_slots,
253
+ proj_head=proj_head,
254
+ dataloader=dataloader,
255
+ dataloader_eval=dataloader_eval,
256
+ optimizer=optimizer,
257
+ scheduler=scheduler,
258
+ embedding_text_labels_norm=embedding_text_labels_norm,
259
+ normalize=normalize,
260
+ args=args,
261
+ epoch=epoch
262
+ )
263
+ print(f'Epoch {epoch} done.')
264
+ epoch += 1
265
+
266
+ # save final model
267
+ torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/final.pt')
268
+ torch.save(unwrap_model(proj_head).state_dict(), f'{args.output_dir}/checkpoints/final_proj_head.pt')
269
+
270
+ torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/final_opt.pt')
271
+
272
+ if args.output_dir.endswith('_temp'):
273
+ # rename temp dir to final dir
274
+ os.rename(args.output_dir, args.output_dir[:-5])
275
+
276
+ class ClipVisionModel(torch.nn.Module):
277
+ def __init__(self, model, args, normalize):
278
+ super().__init__()
279
+ self.model = model
280
+ self.args = args
281
+ self.normalize = normalize
282
+
283
+ def forward(self, vision, output_normalize, return_all_blocks=True, need_OT=False, object_token=None):
284
+ vision = self.normalize(vision)
285
+ embedding, patches = self.model(vision, return_all_blocks=return_all_blocks, need_OT=need_OT, object_token=object_token)
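+ # NOTE: this assumes a modified open_clip vision tower that accepts return_all_blocks/need_OT/object_token and returns (embedding, patches); the stock open_clip VisionTransformer does not take these arguments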
286
+ if output_normalize:
287
+ embedding = F.normalize(embedding, dim=-1)
288
+ return embedding, patches
289
+
290
+
291
+
292
+
293
+
294
+ class ComputeLossWrapper:
295
+ def __init__(self, embedding_orig, embedding_text_labels_norm, reduction='mean', loss=None,
296
+ logit_scale=100.):
297
+ self.embedding_orig = embedding_orig
298
+ self.embedding_text_labels_norm = embedding_text_labels_norm
299
+ self.reduction = reduction
300
+ self.loss_str = loss
301
+ self.logit_scale = logit_scale
302
+
303
+ def __call__(self, embedding, targets):
304
+ return compute_loss(
305
+ loss_str=self.loss_str, embedding=embedding, targets=targets,
306
+ embedding_orig=self.embedding_orig, logit_scale=self.logit_scale,
307
+ embedding_text_labels_norm=self.embedding_text_labels_norm, reduction=self.reduction
308
+ )
309
+
310
+ def train_one_epoch(
311
+ step_total, model, model_orig, model_slots, proj_head, dataloader, optimizer, scheduler, normalize,
312
+ embedding_text_labels_norm, args, epoch, dataloader_eval=None
313
+ ):
314
+ model_orig.eval()
315
+ model.train()
316
+ model_slots.eval()
317
+ proj_head.train()
318
+
319
+ loss_meter = AverageMeter('loss')
320
+ cos_sim_meter = AverageMeter('cos-sim')
321
+ acc_meter = AverageMeter('acc')
322
+ racc_meter = AverageMeter('racc')
323
+
324
+ epoch_start_time = time.time()
325
+ for i, (data, targets) in enumerate(dataloader):
326
+ is_classification = isinstance(targets, torch.Tensor)
327
+ data = data.cuda()
328
+ n_samples = data.shape[0]
329
+ if is_classification:
330
+ targets = targets.cuda()
331
+
332
+ with torch.no_grad():
333
+ embedding_orig, patches_orig = model_orig(vision=data, output_normalize=args.output_normalize)
334
+ reconstruction, slots, masks, x_dinov2 = model_slots(patches_orig) # (B, token, 768)
335
+
336
+ object_token = proj_head(slots)
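+ # project the frozen slots from slot_dim 256 to the ViT token width 1024 so they can be injected as object tokens in the adversarial forward pass below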
337
+
338
+ # loss for the attack
339
+ loss_inner_wrapper = ComputeLossWrapper(
340
+ embedding_orig, embedding_text_labels_norm,
341
+ reduction='none' if args.attack == 'apgd' else 'mean', loss=args.inner_loss,
342
+ logit_scale=100.
343
+ )
344
+ model.eval()
345
+
346
+ if args.attack == 'pgd':
347
+ data_adv = pgd(
348
+ forward=model,
349
+ loss_fn=loss_inner_wrapper,
350
+ data_clean=data,
351
+ targets=targets,
352
+ norm=args.norm,
353
+ eps=args.eps,
354
+ iterations=args.iterations_adv,
355
+ stepsize=args.stepsize_adv,
356
+ output_normalize=args.output_normalize,
357
+ perturbation=torch.zeros_like(data).uniform_(-args.eps, args.eps).requires_grad_(True),
358
+ mode='max',
359
+ verbose=False,
360
+ need_OT=False
361
+ )
362
+ elif args.attack == 'apgd':
363
+ # apgd currently always applies output normalization
364
+ data_adv = apgd(
365
+ model=model,
366
+ loss_fn=loss_inner_wrapper,
367
+ x=data,
368
+ y=targets,
369
+ norm=args.norm,
370
+ eps=args.eps,
371
+ n_iter=args.iterations_adv,
372
+ verbose=True
373
+ )
374
+ elif args.attack == 'none':
375
+ data_adv = data
376
+
377
+ del loss_inner_wrapper
378
+ model.train()
379
+
380
+ embedding_clean, patches_clean = model(data, output_normalize=args.output_normalize)
381
+ if args.clean_weight > 0.:
382
+ loss_clean = compute_loss(
383
+ loss_str=args.loss_clean, embedding=embedding_clean, targets=targets,
384
+ embedding_orig=embedding_orig, logit_scale=100., embedding_text_labels_norm=None
385
+ )
386
+ else:
387
+ loss_clean = 0.
388
+
389
+ embedding_adv, patches_adv = model(data_adv, output_normalize=args.output_normalize, need_OT=True, object_token=object_token)
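+ # only this adversarial pass receives the object tokens (need_OT=True); the clean and original passes above run without them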
390
+ del data, data_adv
391
+
392
+ if args.trades:
393
+ embedding_clean_no_grad = embedding_clean.detach().clone()
394
+ embedding_orig = embedding_orig.cpu()  # .cpu() is not in-place; reassign to actually move the tensor off the GPU
395
+
396
+ loss = compute_loss(
397
+ loss_str=args.loss, embedding=embedding_adv, targets=targets,
398
+ embedding_orig=embedding_orig if not args.trades else embedding_clean_no_grad,
399
+ logit_scale=100., embedding_text_labels_norm=embedding_text_labels_norm
400
+ )
401
+ loss_total = args.clean_weight * loss_clean + (1 - args.clean_weight) * loss
402
+ loss_total.backward()
403
+ optimizer.step()
404
+ optimizer.zero_grad()
405
+ step_total += 1
406
+ scheduler(step_total)
407
+
408
+ with torch.no_grad():
409
+ # only for logging
410
+ embedding_orig = embedding_orig.cuda()  # move back for the logging computations below
411
+ cos_sim_clean = F.cosine_similarity(embedding_clean, embedding_orig, dim=1).mean()
412
+ cos_sim = F.cosine_similarity(embedding_adv, embedding_orig, dim=1).mean()
413
+ if is_classification:
414
+ logits_adv = embedding_adv @ embedding_text_labels_norm
415
+ racc = compute_acc(logits_adv, targets)
416
+ embedding_clean_norm = F.normalize(embedding_clean, dim=1)
417
+ logits_clean = embedding_clean_norm @ embedding_text_labels_norm
418
+ acc = compute_acc(logits_clean, targets)
419
+ acc_meter.update(acc, n_samples)
420
+ racc_meter.update(racc, n_samples)
421
+ del embedding_clean_norm, embedding_clean
422
+ else:
423
+ acc = None
424
+ racc = None
425
+
426
+ loss_meter.update(loss.item(), n_samples)
427
+ cos_sim_meter.update(cos_sim.item(), n_samples)
428
+
429
+ eval_logs = dict()
430
+ if (step_total-1) % args.eval_freq == 0:
431
+ # we compute acc and racc (against supervised apgd) on validation data
432
+ model.eval()
433
+ data_eval, targets_eval = next(iter(dataloader_eval))
434
+ data_eval, targets_eval = data_eval.cuda(), targets_eval.cuda()
435
+ loss_eval_wrapper = ComputeLossWrapper(
436
+ embedding_orig=None, embedding_text_labels_norm=embedding_text_labels_norm,
437
+ reduction='none', loss='ce', logit_scale=100.
438
+ )
439
+ data_eval_adv = apgd(
440
+ model=model,
441
+ loss_fn=loss_eval_wrapper,
442
+ x=data_eval,
443
+ y=targets_eval,
444
+ norm=args.norm,
445
+ eps=args.eps,
446
+ n_iter=50,
447
+ initial_stepsize=0.05 * args.eps if args.clean_weight > 0 else None,
448
+ verbose=False
449
+ )
450
+ with torch.no_grad():
451
+ embedding_adv_eval_norm, patches_adv_eval_norm = model(data_eval_adv, output_normalize=True) # we set output_normalize to True
452
+ logits_eval_adv = embedding_adv_eval_norm @ embedding_text_labels_norm
453
+ racc_eval = compute_acc(logits_eval_adv, targets_eval)
454
+ embedding_eval_norm, patches_adv_eval_norm = model(data_eval, output_normalize=True)
455
+ logits_eval = embedding_eval_norm @ embedding_text_labels_norm
456
+ acc_eval = compute_acc(logits_eval, targets_eval)
457
+ # note we compute the cosine sim between clean and adv embedding,
458
+ # not between orig and adv embedding as for training
459
+ cos_sim_eval = F.cosine_similarity(embedding_adv_eval_norm, embedding_eval_norm, dim=1).mean()
460
+ eval_logs['eval/racc'] = racc_eval
461
+ eval_logs['eval/acc'] = acc_eval
462
+ eval_logs['eval/cos-sim'] = cos_sim_eval
463
+ print(f'[eval-acc] {acc_eval:.2f} [eval-racc] {racc_eval:.2f} [eval-cos-sim] {cos_sim_eval:.3f}')
464
+ model.train()
465
+ del data_eval_adv, data_eval, targets_eval, embedding_adv_eval_norm, logits_eval_adv, embedding_eval_norm, logits_eval
466
+
467
+ lr_ = optimizer.param_groups[0].get('lr')
468
+ if (step_total-1) % args.log_freq == 0:
469
+ log_str = f'[step] {step_total} [lr] {lr_:.6f} [loss] {loss.item():.6f} [cos-sim] {cos_sim.item():.3f}'
470
+ if is_classification:
471
+ log_str += f' [acc] {acc:.2f} [racc] {racc:.2f}'
472
+ print(log_str)
473
+ log_data = {
474
+ 'step': step_total,
475
+ 'lr': lr_,
476
+ 'loss': loss.item(),
477
+ 'loss-total': loss_total.item(),
478
+ 'cos-sim-clean': cos_sim_clean.item(),
479
+ 'cos-sim': cos_sim.item(),
480
+ 'acc': acc,
481
+ 'racc': racc,
482
+ 'avg/loss': loss_meter.avg,
483
+ 'avg/cos-sim': cos_sim_meter.avg,
484
+ 'avg/acc': acc_meter.avg,
485
+ 'avg/racc': racc_meter.avg,
486
+ }
487
+ log_data.update(eval_logs)
488
+ if (step_total-1) % (args.log_freq * 10) == 0:
489
+ # compute expected average epoch time in hours
490
+ batch_average_time = (time.time() - epoch_start_time) / (i + 1) / (60**2)
491
+ epoch_average_time = batch_average_time * len(dataloader)
492
+ this_epoch_remaining = epoch_average_time - \
493
+ (time.time() - epoch_start_time) / 60**2
494
+ total_remaining = epoch_average_time * (args.total_epochs - epoch - i / len(dataloader))
495
+ print(f'[epoch average time] {epoch_average_time:.2f} [this epoch remaining] '
496
+ f'{this_epoch_remaining:.2f} [total remaining] {total_remaining:.2f}')
497
+
498
+ log_data.update({
499
+ 'time/total-remaining': total_remaining,
500
+ 'time/this-epoch-remaining': this_epoch_remaining,
501
+ 'time/epoch-average-time': epoch_average_time,
502
+ 'time/batch-average-time': batch_average_time,
503
+ 'other/epoch': epoch + i / len(dataloader),
504
+ })
505
+ wandb.log(log_data)
506
+
507
+ # save 10 models over the course of training
508
+ if args.save_checkpoints and (step_total % (args.steps // 10) == 0):
509
+ # save model and optimizer state_dict
510
+ torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}.pt')
511
+ torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}_opt.pt')
512
+ # every 200 steps, save a fallback model, which gets overwritten
513
+ if step_total % 200 == 0:
514
+ torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}.pt')
515
+ torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}_opt.pt')
516
+ # remove old fallback models
517
+ for file in os.listdir(f'{args.output_dir}/checkpoints'):
518
+ if file.startswith('fallback') and str(step_total) not in file:
519
+ os.remove(f'{args.output_dir}/checkpoints/{file}')
520
+
521
+ if step_total >= args.steps:
522
+ break
523
+
524
+ torch.cuda.empty_cache()
525
+ return step_total
526
+
527
+
528
+ @torch.no_grad()
529
+ def compute_acc(logits, targets):
530
+ preds_clean = logits.max(dim=1)[1].detach()
531
+ acc = (preds_clean.eq(targets).sum() / targets.shape[0]).item() * 100
532
+ return acc
533
+
534
+
535
+ def compute_loss(loss_str, embedding, targets, embedding_orig, logit_scale,
536
+ embedding_text_labels_norm=None, reduction='mean'):
537
+ if loss_str == 'l2':
538
+ loss = l2(out=embedding, targets=embedding_orig, reduction=reduction)
539
+ elif loss_str == 'ce':
540
+ loss = ce(
541
+ out=embedding @ (logit_scale * embedding_text_labels_norm),
542
+ targets=targets,
543
+ reduction=reduction
544
+ )
545
+ else:
546
+ raise ValueError(f'loss {loss_str} not supported')
547
+ return loss
548
+
549
+ def l2(out, targets, reduction='none'):
550
+ # squared l2 - it does not divide by the latent dimension
551
+ # should have shape (batch_size, embedding_size)
552
+ assert out.shape == targets.shape, f'{out.shape} != {targets.shape}'
553
+ assert out.shape[0] > 1
554
+ # Compute the element-wise squared error
555
+ squared_error_batch = F.mse_loss(out, targets, reduction='none')
556
+ if reduction == 'mean':
557
+ squared_error_batch = torch.mean(squared_error_batch.sum(dim=1))
558
+ else:
559
+ squared_error_batch = squared_error_batch.sum(dim=1)
560
+ assert squared_error_batch.shape == (out.shape[0],), f'{squared_error_batch.shape} != {(out.shape[0],)}'
561
+ return squared_error_batch
562
+
563
+ def ce(out, targets, reduction='mean'):
564
+ # out = logits
565
+ assert out.shape[0] == targets.shape[0], (out.shape, targets.shape)
566
+ assert out.shape[0] > 1
567
+
568
+ return F.cross_entropy(out, targets, reduction=reduction)
569
+
570
+ if __name__ == '__main__':
571
+ # set seeds
572
+ torch.manual_seed(0)
573
+ np.random.seed(0)
574
+
575
+ # Parse command-line arguments
576
+ args = parser.parse_args()
577
+ args.eps /= 255
578
+ args.stepsize_adv /= 255
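+ # eps and step size are given on the command line in [0, 255] pixel scale and converted here to the [0, 1] image scale used by the attacks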
579
+ # make sure there is no string in args that should be a bool
580
+ assert not any([isinstance(x, str) and x in ['True', 'False'] for x in args.__dict__.values()]), f'args contains a string that should be a bool: {args}'
581
+ assert args.eval_freq % args.log_freq == 0, 'eval_freq must be a multiple of log_freq'
582
+
583
+ if args.devices != '':
584
+ # set cuda visible devices
585
+ os.environ['CUDA_VISIBLE_DEVICES'] = args.devices
586
+
587
+ num_gpus = torch.cuda.device_count()
588
+ if num_gpus > 1:
589
+ print(f'Number of GPUs available: {num_gpus}')
590
+ else:
591
+ print('No multiple GPUs available.')
592
+
593
+ # set model name and output dir
594
+ random_str = ''.join(random.choices(string.ascii_letters + string.digits, k=5))
595
+ args.finetuned_model_name = f'{args.clip_model_name}_{args.pretrained}_{args.dataset}_{args.loss}_{args.dataset}_{args.experiment_name}_{random_str}'
596
+ args.finetuned_model_name = args.finetuned_model_name.replace('/', '_')
597
+ args.output_dir = os.path.join(args.output_dir, args.finetuned_model_name)
598
+ # run
599
+ main(args)
train/apgd_train.py CHANGED
@@ -178,7 +178,7 @@ def apgd_train(model, x, y, norm, eps, n_iter=10, use_rs=False, loss_fn=None,
178
  # grad = torch.zeros_like(x)
179
  # for _ in range(self.eot_iter)
180
  # with torch.enable_grad()
181
- logits = model(x_adv, output_normalize=True)
182
  loss_indiv = loss_fn(logits, y)
183
  loss = loss_indiv.sum()
184
  # grad += torch.autograd.grad(loss, [x_adv])[0].detach()
@@ -285,7 +285,7 @@ def apgd_train(model, x, y, norm, eps, n_iter=10, use_rs=False, loss_fn=None,
285
  # grad = torch.zeros_like(x)
286
  # for _ in range(self.eot_iter)
287
  # with torch.enable_grad()
288
- logits = model(x_adv, output_normalize=True)
289
  loss_indiv = loss_fn(logits, y)
290
  loss = loss_indiv.sum()
291
 
 
178
  # grad = torch.zeros_like(x)
179
  # for _ in range(self.eot_iter)
180
  # with torch.enable_grad()
181
+ logits, patch = model(x_adv, output_normalize=True)
182
  loss_indiv = loss_fn(logits, y)
183
  loss = loss_indiv.sum()
184
  # grad += torch.autograd.grad(loss, [x_adv])[0].detach()
 
285
  # grad = torch.zeros_like(x)
286
  # for _ in range(self.eot_iter)
287
  # with torch.enable_grad()
288
+ logits, patch = model(x_adv, output_normalize=True)
289
  loss_indiv = loss_fn(logits, y)
290
  loss = loss_indiv.sum()
291
 
train/pgd_train.py CHANGED
@@ -15,7 +15,8 @@ def pgd(
15
  perturbation=None,
16
  mode='min',
17
  momentum=0.9,
18
- verbose=False
 
19
  ):
20
  """
21
  Minimize or maximize given loss
@@ -29,7 +30,7 @@ def pgd(
29
  for i in range(iterations):
30
  perturbation.requires_grad = True
31
  with torch.enable_grad():
32
- out = forward(data_clean + perturbation, output_normalize=output_normalize)
33
  loss = loss_fn(out, targets)
34
  if verbose:
35
  print(f'[{i}] {loss.item():.5f}')
 
15
  perturbation=None,
16
  mode='min',
17
  momentum=0.9,
18
+ verbose=False,
19
+ need_OT=False
20
  ):
21
  """
22
  Minimize or maximize given loss
 
30
  for i in range(iterations):
31
  perturbation.requires_grad = True
32
  with torch.enable_grad():
33
+ out, patch_out = forward(data_clean + perturbation, output_normalize=output_normalize, need_OT=need_OT)
34
  loss = loss_fn(out, targets)
35
  if verbose:
36
  print(f'[{i}] {loss.item():.5f}')
train/training_clip+dinov2_slots.py ADDED
@@ -0,0 +1,543 @@
1
+ import sys
2
+
3
+
4
+ from train.datasets import COCOFlickrDataset, ImageNetDataset
5
+ from CLIP_eval.eval_utils import load_clip_model
6
+
7
+ sys.path.append("open_flamingo")
8
+ import os
9
+ import shutil
10
+ import time
11
+ import string
12
+ import random
13
+
14
+ import numpy as np
15
+ import open_clip
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from torch.utils.data import DataLoader
19
+ from training.scheduler import cosine_lr
20
+ from torchvision import transforms
21
+ from open_flamingo.eval.classification_utils import IMAGENET_1K_CLASS_ID_TO_LABEL
22
+ from train.pgd_train import pgd
23
+ from train.apgd_train import apgd_train as apgd
24
+ import wandb
25
+ from train.utils import init_wandb, AverageMeter
26
+ from train.sam_data import SamData
27
+ from open_flamingo.eval.models.utils import unwrap_model
28
+ from train.utils import str2bool
29
+ from torch.hub import load
30
+
31
+
32
+
33
+ RANDOM_SEED = 42 # any random number
34
+ def set_seed(seed):
35
+ random.seed(seed)
36
+ np.random.seed(seed)
37
+ torch.manual_seed(seed) # CPU
38
+ torch.cuda.manual_seed(seed) # GPU
39
+ torch.cuda.manual_seed_all(seed) # All GPU
40
+ os.environ['PYTHONHASHSEED'] = str(seed) # disable Python hash randomization
41
+ torch.backends.cudnn.deterministic = True # make cuDNN always choose deterministic convolution algorithms
42
+ torch.backends.cudnn.benchmark = False # True lets cuDNN auto-tune the fastest algorithm for the current configuration; False keeps results reproducible
43
+ set_seed(RANDOM_SEED)
44
+
45
+ from slots.DINOSAUR import DINOSAURpp
46
+ import matplotlib.pyplot as plt
47
+ from einops import rearrange, repeat
48
+
49
+
50
+
51
+ import argparse
52
+
53
+ parser = argparse.ArgumentParser()
54
+ parser.add_argument('--clip_model_name', type=str, default='ViT-L-14', help='ViT-L-14, ViT-B-32')
55
+ parser.add_argument('--pretrained', type=str, default='openai')
56
+ parser.add_argument('--dataset', type=str, default='imagenet')
57
+ parser.add_argument('--template', type=str, default='std')
58
+ parser.add_argument('--imagenet_root', type=str, default='/mnt/datasets/imagenet', help='Imagenet dataset root directory')
59
+ parser.add_argument('--output_normalize', type=str2bool, default=False, help='Whether the embedding is normalized')
60
+ parser.add_argument('--start_step', type=int, default=0, help='Start step for training')
61
+ parser.add_argument('--optimizer_state', type=str, default='', help='Optimizer state file path')
62
+ parser.add_argument('--steps', type=int, default=20000, help='Number of training steps')
63
+ parser.add_argument('--warmup', type=int, default=14000, help='Warmup steps')
64
+ parser.add_argument('--batch_size', type=int, default=256)
65
+ parser.add_argument('--loss', type=str, default='l2', help='ce, l2')
66
+ parser.add_argument('--loss_clean', type=str, default='none', help='ce, l2')
67
+ parser.add_argument('--clean_weight', type=float, default=0., help='Weight for clean loss')
68
+ parser.add_argument('--trades', type=str2bool, default=False, help='Use TRADES')
69
+ parser.add_argument('--opt', type=str, default='adamw', help='Optimizer type; sgd, adamw')
70
+ parser.add_argument('--momentum_sgd', type=float, default=0.9, help='Momentum for SGD optimizer')
71
+ parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate')
72
+ parser.add_argument('--wd', type=float, default=1e-4, help='Weight decay')
73
+ parser.add_argument('--attack', type=str, default='apgd', help='Adversarial attack type')
74
+ parser.add_argument('--inner_loss', type=str, default='l2', help='Inner loss function for adversarial training')
75
+ parser.add_argument('--norm', type=str, default='linf', help='Norm for adversarial perturbation')
76
+ parser.add_argument('--eps', type=float, default=4, help='Epsilon for adversarial perturbation')
77
+ parser.add_argument('--iterations_adv', type=int, default=10, help='Iterations for adversarial attack')
78
+ parser.add_argument('--stepsize_adv', type=float, default=1., help='Step size for adversarial attack (no effect for apgd)')
79
+ parser.add_argument('--wandb', type=str2bool, default=True, help='Use Weights & Biases for logging')
80
+ parser.add_argument('--experiment_name', type=str, default='')
81
+ parser.add_argument('--overwrite', type=str2bool, default=False, help='Overwrite existing directory')
82
+ parser.add_argument('--log_freq', type=int, default=1, help='Logging frequency')
83
+ parser.add_argument('--eval_freq', type=int, default=50, help='Evaluation frequency')
84
+ parser.add_argument('--output_dir', type=str, default='', help='Output directory')
85
+ parser.add_argument('--save_checkpoints', type=str2bool, default=True, help='Save 10 training checkpoints')
86
+ parser.add_argument('--devices', type=str, default='', help='Device IDs for CUDA')
87
+
88
+
89
+ def main(args):
90
+ # setup wandb
91
+ if args.wandb:
92
+ init_wandb(
93
+ project_name='clip-finetune',
94
+ model_name=args.finetuned_model_name,
95
+ config=vars(args)
96
+ )
97
+ else:
98
+ wandb.init(mode='disabled')
99
+
100
+ # print args
101
+ print(f"Arguments:\n{'-' * 20}")
102
+ for arg, value in vars(args).items():
103
+ print(f"{arg}: {value}")
104
+ print(f"{'-' * 20}")
105
+
106
+ # setup dirs
107
+ if args.overwrite:
108
+ shutil.rmtree(args.output_dir, ignore_errors=True)
109
+ os.makedirs(os.path.join(args.output_dir, 'checkpoints'), exist_ok=False)
110
+
111
+ # write args to file
112
+ with open(os.path.join(args.output_dir, 'args.txt'), 'w') as f:
113
+ f.write(str(args))
114
+
115
+ main_device = 0
116
+ # get models
117
+ from open_clip.model import CLIPVisionCfg
118
+ CLIPVisionCfg.output_tokens = True
119
+ model_orig, _, image_processor = open_clip.create_model_and_transforms(
120
+ args.clip_model_name, pretrained='openai'  #, output_tokens=True  # optionally pass output_tokens=True to return the CLS token plus patch tokens
121
+ )
122
+ # Remove the Normalize transform by creating a new Compose object
123
+ preprocessor_without_normalize = transforms.Compose(image_processor.transforms[:-1])
124
+ normalize = image_processor.transforms[-1]
125
+ del image_processor
126
+ print(f'[preprocessor_without_normalize] {preprocessor_without_normalize}')
127
+
128
+ ####################################################### get slot-attention model #########################################################
129
+ cfg_dict = {'slot_dim': 256, 'num_slots': 10, 'token_num': 256, 'ISA': False, 'slot_att_iter': 3, 'query_opt': False}
130
+ model_slots = DINOSAURpp(cfg_dict)
131
+
132
+ # get data
133
+ if args.dataset == 'imagenet':
134
+ dataset = ImageNetDataset(
135
+ root=args.imagenet_root + '/train',
136
+ transform=preprocessor_without_normalize,
137
+ )
138
+
139
+ elif args.dataset == 'segment_anything':
140
+ dataset = SamData('/data/naman_deep_singh/datasets/newSAM', transform=preprocessor_without_normalize)
141
+
142
+ print(dataset.__len__())
143
+ elif args.dataset == 'coco':
144
+ if os.path.exists('/mnt/datasets/coco'):
145
+ image_dir_path = '/mnt/datasets/coco/train2017'
146
+ annotations_path = '/mnt/datasets/coco/annotations/captions_train2017.json'
147
+ elif os.path.exists('/mnt/lustre'):
148
+ image_dir_path = '/mnt/lustre/hein/cschlarmann37/datasets/coco/train2017'
149
+ annotations_path = '/mnt/lustre/hein/cschlarmann37/datasets/coco/annotations/captions_train2017.json'
150
+ else:
151
+ raise ValueError('COCO dataset not found')
152
+ dataset = COCOFlickrDataset(
153
+ image_dir_path=image_dir_path,
154
+ annotations_path=annotations_path,
155
+ transform=preprocessor_without_normalize
156
+ )
157
+ dataset_eval = ImageNetDataset(
158
+ root=args.imagenet_root + '/val',
159
+ transform=preprocessor_without_normalize,
160
+ )
161
+ dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=8, drop_last=True)
162
+ dataloader_eval = DataLoader(dataset_eval, batch_size=args.batch_size, shuffle=True, num_workers=8, drop_last=True)
163
+
164
+ # Get text label embeddings of all ImageNet classes
165
+ if args.template == 'std':
166
+ template = 'This is a photo of a {}'
167
+ elif args.template == 'blurry':
168
+ template = 'This is a blurry photo of a {}'
169
+ else:
170
+ raise ValueError(f'Unknown template: {args.template}')
171
+ print(f'template: {template}')
172
+ texts = [template.format(c) for c in IMAGENET_1K_CLASS_ID_TO_LABEL.values()]
173
+ text_tokens = open_clip.tokenize(texts)
174
+ model_orig.to(main_device)
175
+ with torch.no_grad():
176
+ embedding_text_labels_norm = []
177
+ for el in (text_tokens[:500], text_tokens[500:]):
178
+ # we need to split the text tokens into two batches because otherwise we run out of memory
179
+ # note that we are accessing the model directly here, not the CustomModel wrapper
180
+ # thus it's always normalizing the text embeddings
181
+ embedding_text_labels_norm.append(
182
+ model_orig.encode_text(el.to(main_device), normalize=True).detach().cpu()
183
+ )
184
+ embedding_text_labels_norm = torch.cat(embedding_text_labels_norm).T.to(main_device)
185
+ assert torch.allclose(
186
+ F.normalize(embedding_text_labels_norm, dim=0),
187
+ embedding_text_labels_norm
188
+ )
189
+ if args.clip_model_name == 'ViT-B-32':
190
+ assert embedding_text_labels_norm.shape == (512, 1000), embedding_text_labels_norm.shape
191
+ elif args.clip_model_name in ('ViT-L-14', 'ViT-L-14-336'):
192
+ assert embedding_text_labels_norm.shape == (768, 1000), embedding_text_labels_norm.shape
193
+ else:
194
+ raise ValueError(f'Unknown model: {args.clip_model_name}')
195
+
196
+ model_orig.cpu()
197
+ model_orig = ClipVisionModel(model=model_orig.visual, args=args, normalize=normalize)
198
+ if num_gpus > 1:
199
+ model_orig = torch.nn.DataParallel(model_orig)
200
+ model_orig.cuda()
201
+
202
+ model_slots = model_slots
203
+ if num_gpus > 1:
204
+ model_slots = torch.nn.DataParallel(model_slots)
205
+ model_slots.cuda()
206
+
207
+
208
+
209
+ ####################################################### add dino v2 as target model #####################################################
210
+ target_backbone_name = 'dinov2_vitb14'
211
+ model_dinov2 = load('/home/tly/.cache/torch/hub/facebookresearch_dinov2_main', target_backbone_name, source='local') # /home/ly/.cache/torch/hub/facebookresearch_dinov2_main facebookresearch/dinov2
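+ # loads DINOv2 ViT-B/14 from a local torch.hub cache; the commented line below pulls the same backbone from the facebookresearch/dinov2 hub instead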
212
+ # model_dinov2 = load('facebookresearch/dinov2', target_backbone_name) # /home/ly/.cache/torch/hub/facebookresearch_dinov2_main facebookresearch/dinov2
213
+
214
+ proj_head = torch.nn.Linear(1024, 768)
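+ # intended to map the 1024-d reconstructed CLIP patch features to the 768-d DINOv2 feature space; note the reconstruction_to_dinov2 call in the loss below is currently commented out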
215
+ if num_gpus > 1:
216
+ model_dinov2 = torch.nn.DataParallel(model_dinov2)
217
+ proj_head = torch.nn.DataParallel(proj_head)
218
+ model_dinov2.cuda()
219
+ proj_head.cuda()
220
+
221
+ # set optimizer (all params have requires_grad=True)
222
+ params_slots = unwrap_model(model_slots).parameters()
223
+ params_head = unwrap_model(proj_head).parameters()
224
+
225
+ if args.opt == 'adamw':
226
+ optimizer = torch.optim.AdamW(
227
+ [{'params': params_slots},
228
+ {'params': params_head}], lr=args.lr, weight_decay=args.wd)
229
+ elif args.opt == 'sgd':
230
+ optimizer = torch.optim.SGD(
231
+ [{'params': params_slots},
232
+ {'params': params_head}],
233
+ lr=args.lr,
234
+ momentum=args.momentum_sgd,
235
+ weight_decay=args.wd
236
+ )
237
+ else:
238
+ raise ValueError(f'Optimizer {args.opt} not supported.')
239
+ if args.optimizer_state != '':
240
+ optimizer.load_state_dict(torch.load(args.optimizer_state))
241
+
242
+ # set scheduler
243
+ scheduler = cosine_lr(optimizer, args.lr, args.warmup, args.steps)
244
+
245
+ # compute amount of epochs
246
+ total_epochs = args.steps / len(dataloader)
247
+ print(f'train for {total_epochs} epochs')
248
+ args.total_epochs = total_epochs
249
+
250
+ # finetune
251
+ step_total = args.start_step
252
+ epoch = 0
253
+ while step_total < args.steps:
254
+ step_total = train_one_epoch_slots(
255
+ step_total,
256
+ model_slots=model_slots,
257
+ model_orig=model_orig,
258
+ model_dinov2=model_dinov2,
259
+ proj_head=proj_head,
260
+ dataloader=dataloader,
261
+ dataloader_eval=dataloader_eval,
262
+ optimizer=optimizer,
263
+ scheduler=scheduler,
264
+ embedding_text_labels_norm=embedding_text_labels_norm,
265
+ normalize=normalize,
266
+ args=args,
267
+ epoch=epoch
268
+ )
269
+ print(f'Epoch {epoch} done.')
270
+ epoch += 1
271
+
272
+ # save final model
273
+ torch.save(unwrap_model(model_slots).state_dict(), f'{args.output_dir}/checkpoints/final.pt')
274
+ torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/final_opt.pt')
275
+
276
+ if args.output_dir.endswith('_temp'):
277
+ # rename temp dir to final dir
278
+ os.rename(args.output_dir, args.output_dir[:-5])
279
+
280
+ class ClipVisionModel(torch.nn.Module):
281
+ def __init__(self, model, args, normalize):
282
+ super().__init__()
283
+ self.model = model
284
+ self.args = args
285
+ self.normalize = normalize
286
+
287
+ def forward(self, vision, output_normalize, return_all_blocks=False):
288
+ vision = self.normalize(vision)
289
+ embedding, patches = self.model(vision, return_all_blocks=return_all_blocks)
290
+ if output_normalize:
291
+ embedding = F.normalize(embedding, dim=-1)
292
+ return embedding, patches
293
+
294
+
295
+ class ComputeLossWrapper:
296
+ def __init__(self, embedding_orig, embedding_text_labels_norm, reduction='mean', loss=None,
297
+ logit_scale=100.):
298
+ self.embedding_orig = embedding_orig
299
+ self.embedding_text_labels_norm = embedding_text_labels_norm
300
+ self.reduction = reduction
301
+ self.loss_str = loss
302
+ self.logit_scale = logit_scale
303
+
304
+ def __call__(self, embedding, targets):
305
+ return compute_loss(
306
+ loss_str=self.loss_str, embedding=embedding, targets=targets,
307
+ embedding_orig=self.embedding_orig, logit_scale=self.logit_scale,
308
+ embedding_text_labels_norm=self.embedding_text_labels_norm, reduction=self.reduction
309
+ )
310
+
311
+ def mean_flat(x):
312
+ """
313
+ Take the mean over all non-batch dimensions.
314
+ """
315
+ return torch.mean(x, dim=list(range(1, len(x.size()))))
316
+
317
+
318
+ def train_one_epoch_slots(
319
+ step_total, model_slots, model_orig, model_dinov2, proj_head, dataloader, optimizer, scheduler, normalize,
320
+ embedding_text_labels_norm, args, epoch, dataloader_eval=None
321
+ ):
322
+ model_orig.eval()
323
+ model_slots.train()
324
+
325
+ MSEFunc = torch.nn.MSELoss()
326
+
327
+
328
+ loss_meter = AverageMeter('loss')
329
+
330
+
331
+ epoch_start_time = time.time()
332
+ for i, (data, targets) in enumerate(dataloader):
333
+ is_classification = isinstance(targets, torch.Tensor)
334
+ data = data.cuda()
335
+ n_samples = data.shape[0]
336
+ if is_classification:
337
+ targets = targets.cuda()
338
+
339
+ with torch.no_grad():
340
+ embedding_orig, patches_orig = model_orig(vision=data, output_normalize=args.output_normalize, return_all_blocks=True)
341
+ feat_dinov2_dict = model_dinov2(data, is_training=True)
342
+
343
+ if num_gpus > 1:
344
+ patches_orig = model_orig.module.model.ln_pre(patches_orig)
345
+ else:
346
+ patches_orig = model_orig.model.ln_pre(patches_orig)
347
+
348
+
349
+ reconstruction, slots, masks, x_dinov2 = model_slots(patches_orig) # (B, token, 768)
350
+ x_dinov2 = torch.sum(x_dinov2, dim=1)
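+ # presumably collapses the per-slot decoded features over the slot axis into a single (B, tokens, C) prediction per image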
351
+
352
+ # change target to dino features
353
+ # reconstruction_to_dinov2 = proj_head(reconstruction)
354
+
355
+ b, hw, c = patches_orig.shape
356
+ h, w = int(np.sqrt(hw)), int(np.sqrt(hw))
357
+ k = slots.size(1)
358
+ c_dinov2 = feat_dinov2_dict['x_norm_patchtokens'].size(-1)
359
+
360
+ reconstruction = rearrange(reconstruction, 'b (h w) c -> b c h w', h=h, w=w)
361
+ masks = rearrange(masks, 'b k (h w) -> b k h w', h=h, w=w, k=k)
362
+ patches_orig = rearrange(patches_orig, 'b (h w) c -> b c h w', h=h, w=w)
363
+
364
+
365
+ # loss for the attack
366
+ loss = MSEFunc(reconstruction, patches_orig)
367
+ x_dinov2 = torch.nn.functional.normalize(x_dinov2, dim=-1)
368
+ feat_dinov2 = torch.nn.functional.normalize(feat_dinov2_dict['x_norm_patchtokens'], dim=-1)
369
+ # x_dinov2 = torch.nn.functional.normalize(x_dinov2, dim=-1)
370
+ # feat_dinov2 = torch.nn.functional.normalize(feat_dinov2_dict['x_norm_patchtokens'], dim=-1)
371
+ # loss_dinov2 = MSEFunc(feat_dinov2, x_dinov2)#reconstruction_to_dinov2)
372
+ loss_dinov2 = mean_flat(-(x_dinov2 * feat_dinov2).sum(dim=-1)).mean()
373
+
374
+ loss_total = loss + 0.0001*loss_dinov2
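+ # total objective: slot reconstruction MSE on CLIP patches plus a small-weight (1e-4) negative-cosine alignment to DINOv2 patch tokens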
375
+
376
+
377
+ loss_total.backward()
378
+ optimizer.step()
379
+ optimizer.zero_grad()
380
+ step_total += 1
381
+ scheduler(step_total)
382
+
383
+ lr_ = optimizer.param_groups[0].get('lr')
384
+ if (step_total-1) % args.log_freq == 0:
385
+ log_str = f'[step] {step_total} [lr] {lr_:.6f} [loss] {loss.item():.6f} [loss_dinov2] {loss_dinov2.item():.6f}'
386
+ print(log_str)
387
+ log_data = {
388
+ 'step': step_total,
389
+ 'lr': lr_,
390
+ 'loss': loss.item(),
391
+ 'loss_dinov2': loss_dinov2.item(),
392
+ 'loss-total': loss_total.item(),
393
+ 'avg/loss': loss_meter.avg,
394
+ }
395
+ if (step_total-1) % (args.log_freq * 10) == 0:
396
+ # compute expected average epoch time in hours
397
+ batch_average_time = (time.time() - epoch_start_time) / (i + 1) / (60**2)
398
+ epoch_average_time = batch_average_time * len(dataloader)
399
+ this_epoch_remaining = epoch_average_time - \
400
+ (time.time() - epoch_start_time) / 60**2
401
+ total_remaining = epoch_average_time * (args.total_epochs - epoch - i / len(dataloader))
402
+ print(f'[epoch average time] {epoch_average_time:.2f} [this epoch remaining] '
403
+ f'{this_epoch_remaining:.2f} [total remaining] {total_remaining:.2f}')
404
+
405
+ log_data.update({
406
+ 'time/total-remaining': total_remaining,
407
+ 'time/this-epoch-remaining': this_epoch_remaining,
408
+ 'time/epoch-average-time': epoch_average_time,
409
+ 'time/batch-average-time': batch_average_time,
410
+ 'other/epoch': epoch + i / len(dataloader),
411
+ })
412
+ wandb.log(log_data)
413
+
414
+ # save 10 models over the course of training
415
+ if args.save_checkpoints and (step_total % (args.steps // 10) == 0):
416
+ # save model and optimizer state_dict
417
+ torch.save(unwrap_model(model_slots).state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}.pt')
418
+ torch.save(unwrap_model(proj_head).state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}_proj.pt')
419
+
420
+ torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}_opt.pt')
421
+ # every 200 steps, save a fallback model, which gets overwritten
422
+ if step_total % 200 == 0:
423
+ torch.save(unwrap_model(model_slots).state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}.pt')
424
+ torch.save(unwrap_model(proj_head).state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}_proj.pt')
425
+
426
+ torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}_opt.pt')
427
+ # remove old fallback models
428
+ for file in os.listdir(f'{args.output_dir}/checkpoints'):
429
+ if file.startswith('fallback') and str(step_total) not in file:
430
+ os.remove(f'{args.output_dir}/checkpoints/{file}')
431
+
432
+ ######################################################## Save ori Image and recon Image ########################
433
+ # if epoch % 5 == 0:
434
+ save_pics_path = os.path.join(args.output_dir, 'slots_recons')
435
+
436
+ recon_pic_save_path = os.path.join(save_pics_path, args.dataset)
437
+ os.makedirs(recon_pic_save_path, exist_ok=True)
438
+
439
+ plt.imshow(reconstruction[0, 0].detach().cpu().numpy())
440
+ save_name = 'recon_pic_steps{}.png'.format(step_total)
441
+ plt.savefig(os.path.join(recon_pic_save_path, save_name))
442
+
443
+ plt.imshow(patches_orig[0, 0].detach().cpu().numpy())
444
+ save_name = 'recon_pic_steps{}_feat.png'.format(step_total)
445
+ plt.savefig(os.path.join(recon_pic_save_path, save_name))
446
+
447
+
448
+ plt.imshow(rearrange(x_dinov2, 'b (h w) c_dinov2 -> b c_dinov2 h w', h=h, w=w)[0, 0].detach().cpu().numpy())
449
+ save_name = 'recon_pic_steps{}_dinov2.png'.format(step_total)
450
+ plt.savefig(os.path.join(recon_pic_save_path, save_name))
451
+
452
+ plt.imshow(rearrange(feat_dinov2_dict['x_norm_patchtokens'], 'b (h w) c_dinov2 -> b c_dinov2 h w', h=h, w=w)[0, 0].detach().cpu().numpy())
453
+ save_name = 'recon_pic_steps{}_feat_dinov2.png'.format(step_total)
454
+ plt.savefig(os.path.join(recon_pic_save_path, save_name))
455
+
456
+ plt.imshow(data[0].permute(1, 2, 0).detach().cpu().numpy())
457
+ save_name = 'recon_pic_steps{}_ori.png'.format(step_total)
458
+ plt.savefig(os.path.join(recon_pic_save_path, save_name))
459
+
460
+ plt.close('all')
461
+
462
+ if step_total >= args.steps:
463
+ break
464
+
465
+ # torch.cuda.empty_cache()
466
+
467
+
468
+
469
+ return step_total
470
+
471
+
472
+ @torch.no_grad()
473
+ def compute_acc(logits, targets):
474
+ preds_clean = logits.max(dim=1)[1].detach()
475
+ acc = (preds_clean.eq(targets).sum() / targets.shape[0]).item() * 100
476
+ return acc
477
+
478
+
479
+ def compute_loss(loss_str, embedding, targets, embedding_orig, logit_scale,
480
+ embedding_text_labels_norm=None, reduction='mean'):
481
+ if loss_str == 'l2':
482
+ loss = l2(out=embedding, targets=embedding_orig, reduction=reduction)
483
+ elif loss_str == 'ce':
484
+ loss = ce(
485
+ out=embedding @ (logit_scale * embedding_text_labels_norm),
486
+ targets=targets,
487
+ reduction=reduction
488
+ )
489
+ else:
490
+ raise ValueError(f'loss {loss_str} not supported')
491
+ return loss
492
+
493
+ def l2(out, targets, reduction='none'):
494
+ # squared l2 - it does not divide by the latent dimension
495
+ # should have shape (batch_size, embedding_size)
496
+ assert out.shape == targets.shape, f'{out.shape} != {targets.shape}'
497
+ assert out.shape[0] > 1
498
+ # Compute the element-wise squared error
499
+ squared_error_batch = F.mse_loss(out, targets, reduction='none')
500
+ if reduction == 'mean':
501
+ squared_error_batch = torch.mean(squared_error_batch.sum(dim=1))
502
+ else:
503
+ squared_error_batch = squared_error_batch.sum(dim=1)
504
+ assert squared_error_batch.shape == (out.shape[0],), f'{squared_error_batch.shape} != {(out.shape[0],)}'
505
+ return squared_error_batch
506
+
507
+ def ce(out, targets, reduction='mean'):
508
+ # out = logits
509
+ assert out.shape[0] == targets.shape[0], (out.shape, targets.shape)
510
+ assert out.shape[0] > 1
511
+
512
+ return F.cross_entropy(out, targets, reduction=reduction)
513
+
514
+ if __name__ == '__main__':
515
+ # set seeds
516
+ torch.manual_seed(0)
517
+ np.random.seed(0)
518
+
519
+ # Parse command-line arguments
520
+ args = parser.parse_args()
521
+ args.eps /= 255
522
+ args.stepsize_adv /= 255
523
+ # make sure there is no string in args that should be a bool
524
+ assert not any([isinstance(x, str) and x in ['True', 'False'] for x in args.__dict__.values()]), f'args contains a string that should be a bool: {args}'
525
+ assert args.eval_freq % args.log_freq == 0, 'eval_freq must be a multiple of log_freq'
526
+
527
+ if args.devices != '':
528
+ # set cuda visible devices
529
+ os.environ['CUDA_VISIBLE_DEVICES'] = args.devices
530
+
531
+ num_gpus = torch.cuda.device_count()
532
+ if num_gpus > 1:
533
+ print(f'Number of GPUs available: {num_gpus}')
534
+ else:
535
+ print('No multiple GPUs available.')
536
+
537
+ # set model name and output dir
538
+ random_str = ''.join(random.choices(string.ascii_letters + string.digits, k=5))
539
+ args.finetuned_model_name = f'{args.clip_model_name}_{args.pretrained}_{args.dataset}_{args.loss}_{args.dataset}_{args.experiment_name}_{random_str}'
540
+ args.finetuned_model_name = args.finetuned_model_name.replace('/', '_')
541
+ args.output_dir = os.path.join(args.output_dir, args.finetuned_model_name)
542
+ # run
543
+ main(args)
train/training_clip_slots.py CHANGED
@@ -27,6 +27,18 @@ from train.sam_data import SamData
27
  from open_flamingo.eval.models.utils import unwrap_model
28
 from train.utils import str2bool
29

30
  from slots.DINOSAUR import DINOSAURpp
31
  import matplotlib.pyplot as plt
32
  from einops import rearrange, repeat
@@ -393,7 +405,7 @@ def train_one_epoch_slots(
393
  if step_total >= args.steps:
394
  break
395
 
396
- torch.cuda.empty_cache()
397
 
398
 
399
 
 
27
  from open_flamingo.eval.models.utils import unwrap_model
28
  from train.utils import str2bool
29
 
30
+ RANDOM_SEED = 42 # any random number
31
+ def set_seed(seed):
32
+ random.seed(seed)
33
+ np.random.seed(seed)
34
+ torch.manual_seed(seed) # CPU
35
+ torch.cuda.manual_seed(seed) # GPU
36
+ torch.cuda.manual_seed_all(seed) # All GPU
37
+ os.environ['PYTHONHASHSEED'] = str(seed) # disable Python hash randomization
38
+ torch.backends.cudnn.deterministic = True # make cuDNN always choose deterministic convolution algorithms
39
+ torch.backends.cudnn.benchmark = False # True lets cuDNN auto-tune the fastest algorithm for the current configuration; False keeps results reproducible
40
+ set_seed(RANDOM_SEED)
41
+
42
  from slots.DINOSAUR import DINOSAURpp
43
  import matplotlib.pyplot as plt
44
  from einops import rearrange, repeat
 
405
  if step_total >= args.steps:
406
  break
407
 
408
+ # torch.cuda.empty_cache()
409
 
410
 
411