Upload 13 files
- ckps/model_slots_step_300000.pt +3 -0
- output_slots/none.txt +0 -0
- train/adversarial_training_clip_with_object_token.py +599 -0
- train/apgd_train.py +2 -2
- train/pgd_train.py +3 -2
- train/training_clip+dinov2_slots.py +543 -0
- train/training_clip_slots.py +13 -1
ckps/model_slots_step_300000.pt
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:22dbaabe6af89f9d67f8127c856c0cbf95a0ef2447b499d7eacefc3e258b29e1
+size 54346897
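Note: the checkpoint above is stored with Git LFS, so only the pointer (oid and size) appears in the diff. A minimal sketch of restoring it into the slot-attention model, assuming the `DINOSAURpp` config used by the training scripts in this commit and that the file holds a plain `state_dict` (both are assumptions, not verified from the diff):

```python
# Sketch: load the uploaded slot-attention checkpoint.
# Assumes the cfg_dict from train/training_clip_slots.py and a plain state_dict inside the .pt file.
import torch
from slots.DINOSAUR import DINOSAURpp  # module added elsewhere in this repo

cfg_dict = {'slot_dim': 256, 'num_slots': 10, 'token_num': 256,
            'ISA': False, 'slot_att_iter': 3, 'query_opt': False}
model_slots = DINOSAURpp(cfg_dict)
state = torch.load('ckps/model_slots_step_300000.pt', map_location='cpu')
model_slots.load_state_dict(state)
model_slots.eval()
```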
output_slots/none.txt
ADDED
File without changes
train/adversarial_training_clip_with_object_token.py
ADDED
@@ -0,0 +1,599 @@
+import sys
+
+from train.datasets import COCOFlickrDataset, ImageNetDataset
+from CLIP_eval.eval_utils import load_clip_model
+
+sys.path.append("open_flamingo")
+import os
+import shutil
+import time
+import string
+import random
+
+import numpy as np
+import open_clip
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from training.scheduler import cosine_lr
+from torchvision import transforms
+from open_flamingo.eval.classification_utils import IMAGENET_1K_CLASS_ID_TO_LABEL
+from train.pgd_train import pgd
+from train.apgd_train import apgd_train as apgd
+import wandb
+from train.utils import init_wandb, AverageMeter
+from train.sam_data import SamData
+from open_flamingo.eval.models.utils import unwrap_model
+from train.utils import str2bool
+
+import argparse
+
+from slots.DINOSAUR import DINOSAURpp
+import matplotlib.pyplot as plt
+from einops import rearrange, repeat
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--clip_model_name', type=str, default='ViT-L-14', help='ViT-L-14, ViT-B-32')
+parser.add_argument('--pretrained', type=str, default='openai')
+parser.add_argument('--dataset', type=str, default='imagenet')
+parser.add_argument('--template', type=str, default='std')
+parser.add_argument('--imagenet_root', type=str, default='/mnt/datasets/imagenet', help='Imagenet dataset root directory')
+parser.add_argument('--output_normalize', type=str2bool, default=False, help='Whether the embedding is normalized')
+parser.add_argument('--start_step', type=int, default=0, help='Start step for training')
+parser.add_argument('--optimizer_state', type=str, default='', help='Optimizer state file path')
+parser.add_argument('--steps', type=int, default=20000, help='Number of training steps')
+parser.add_argument('--warmup', type=int, default=14000, help='Warmup steps')
+parser.add_argument('--batch_size', type=int, default=256)
+parser.add_argument('--loss', type=str, default='l2', help='ce, l2')
+parser.add_argument('--loss_clean', type=str, default='none', help='ce, l2')
+parser.add_argument('--clean_weight', type=float, default=0., help='Weight for clean loss')
+parser.add_argument('--trades', type=str2bool, default=False, help='Use TRADES')
+parser.add_argument('--opt', type=str, default='adamw', help='Optimizer type; sgd, adamw')
+parser.add_argument('--momentum_sgd', type=float, default=0.9, help='Momentum for SGD optimizer')
+parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate')
+parser.add_argument('--wd', type=float, default=1e-4, help='Weight decay')
+parser.add_argument('--attack', type=str, default='apgd', help='Adversarial attack type')
+parser.add_argument('--inner_loss', type=str, default='l2', help='Inner loss function for adversarial training')
+parser.add_argument('--norm', type=str, default='linf', help='Norm for adversarial perturbation')
+parser.add_argument('--eps', type=float, default=4, help='Epsilon for adversarial perturbation')
+parser.add_argument('--iterations_adv', type=int, default=10, help='Iterations for adversarial attack')
+parser.add_argument('--stepsize_adv', type=float, default=1., help='Step size for adversarial attack (no effect for apgd)')
+parser.add_argument('--wandb', type=str2bool, default=True, help='Use Weights & Biases for logging')
+parser.add_argument('--experiment_name', type=str, default='')
+parser.add_argument('--overwrite', type=str2bool, default=False, help='Overwrite existing directory')
+parser.add_argument('--log_freq', type=int, default=1, help='Logging frequency')
+parser.add_argument('--eval_freq', type=int, default=50, help='Evaluation frequency')
+parser.add_argument('--output_dir', type=str, default='', help='Output directory')
+parser.add_argument('--save_checkpoints', type=str2bool, default=True, help='Save 10 training checkpoints')
+parser.add_argument('--devices', type=str, default='', help='Device IDs for CUDA')
+
+
+######################################### For object-centric relation reasoning ###########################
+parser.add_argument('--slots_ckp', type=str, default='/home/tly/RobustVLM/output_slots/ViT-L-14_openai_imagenet_l2_imagenet_SLOTS_NbrnT/checkpoints/fallback_390200.pt', help='slots model ckp root directory')
+
+
+def main(args):
+    # setup wandb
+    if args.wandb:
+        init_wandb(
+            project_name='clip-finetune',
+            model_name=args.finetuned_model_name,
+            config=vars(args)
+        )
+    else:
+        wandb.init(mode='disabled')
+
+    # print args
+    print(f"Arguments:\n{'-' * 20}")
+    for arg, value in vars(args).items():
+        print(f"{arg}: {value}")
+    print(f"{'-' * 20}")
+
+    # setup dirs
+    if args.overwrite:
+        shutil.rmtree(args.output_dir, ignore_errors=True)
+    os.makedirs(os.path.join(args.output_dir, 'checkpoints'), exist_ok=False)
+
+    # write args to file
+    with open(os.path.join(args.output_dir, 'args.txt'), 'w') as f:
+        f.write(str(args))
+
+    main_device = 0
+    # get models
+    model_orig, _, image_processor = open_clip.create_model_and_transforms(
+        args.clip_model_name, pretrained='openai'  # optionally pass output_tokens=True to return the class token plus the patch tokens
+    )
+    if args.optimizer_state != '':
+        assert args.start_step > 0
+        assert str(args.start_step) in args.optimizer_state
+        assert args.pretrained in ['', 'none']
+        args.pretrained = args.optimizer_state.replace('_opt', '')
+    model, _, _ = load_clip_model(args.clip_model_name, args.pretrained)
+
+    # Remove the Normalize transform by creating a new Compose object
+    preprocessor_without_normalize = transforms.Compose(image_processor.transforms[:-1])
+    normalize = image_processor.transforms[-1]
+    del image_processor
+    print(f'[preprocessor_without_normalize] {preprocessor_without_normalize}')
+    print(f'[normalize] {normalize}')
+    # preprocessor_without_normalize contains following transforms:
+    # - Resize(size=224, interpolation=bicubic, max_size=None, antialias=warn)
+    # - CenterCrop(size=(224, 224))
+    # - ToTensor()
+    # normalize:
+    # Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))
+
+
+    ####################################################### get slot-attention model #########################################################
+    cfg_dict = {'slot_dim': 256, 'num_slots': 10, 'token_num': 256, 'ISA': False, 'slot_att_iter': 3, 'query_opt': False}
+    model_slots = DINOSAURpp(cfg_dict)
+    proj_head = torch.nn.Linear(256, 1024)  # project slot_dim (256) to the CLIP ViT-L width (1024)
+
+
+    # get data
+    if args.dataset == 'imagenet':
+        dataset = ImageNetDataset(
+            root=args.imagenet_root + '/train',
+            transform=preprocessor_without_normalize,
+        )
+
+    elif args.dataset == 'segment_anything':
+        dataset = SamData('/data/naman_deep_singh/datasets/newSAM', transform=preprocessor_without_normalize)
+
+        print(dataset.__len__())
+    elif args.dataset == 'coco':
+        if os.path.exists('/mnt/datasets/coco'):
+            image_dir_path = '/mnt/datasets/coco/train2017'
+            annotations_path = '/mnt/datasets/coco/annotations/captions_train2017.json'
+        elif os.path.exists('/mnt/lustre'):
+            image_dir_path = '/mnt/lustre/hein/cschlarmann37/datasets/coco/train2017'
+            annotations_path = '/mnt/lustre/hein/cschlarmann37/datasets/coco/annotations/captions_train2017.json'
+        else:
+            raise ValueError('COCO dataset not found')
+        dataset = COCOFlickrDataset(
+            image_dir_path=image_dir_path,
+            annotations_path=annotations_path,
+            transform=preprocessor_without_normalize
+        )
+    dataset_eval = ImageNetDataset(
+        root=args.imagenet_root + '/val',
+        transform=preprocessor_without_normalize,
+    )
+    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=8, drop_last=True)
+    dataloader_eval = DataLoader(dataset_eval, batch_size=args.batch_size, shuffle=True, num_workers=8, drop_last=True)
+
+    # Get text label embeddings of all ImageNet classes
+    if args.template == 'std':
+        template = 'This is a photo of a {}'
+    elif args.template == 'blurry':
+        template = 'This is a blurry photo of a {}'
+    else:
+        raise ValueError(f'Unknown template: {args.template}')
+    print(f'template: {template}')
+    texts = [template.format(c) for c in IMAGENET_1K_CLASS_ID_TO_LABEL.values()]
+    text_tokens = open_clip.tokenize(texts)
+    model_orig.to(main_device)
+    with torch.no_grad():
+        embedding_text_labels_norm = []
+        for el in (text_tokens[:500], text_tokens[500:]):
+            # we need to split the text tokens into two batches because otherwise we run out of memory
+            # note that we are accessing the model directly here, not the CustomModel wrapper
+            # thus it is always normalizing the text embeddings
+            embedding_text_labels_norm.append(
+                model_orig.encode_text(el.to(main_device), normalize=True).detach().cpu()
+            )
+        embedding_text_labels_norm = torch.cat(embedding_text_labels_norm).T.to(main_device)
+    assert torch.allclose(
+        F.normalize(embedding_text_labels_norm, dim=0),
+        embedding_text_labels_norm
+    )
+    if args.clip_model_name == 'ViT-B-32':
+        assert embedding_text_labels_norm.shape == (512, 1000), embedding_text_labels_norm.shape
+    elif args.clip_model_name in ('ViT-L-14', 'ViT-L-14-336'):
+        assert embedding_text_labels_norm.shape == (768, 1000), embedding_text_labels_norm.shape
+    else:
+        raise ValueError(f'Unknown model: {args.clip_model_name}')
+
+    model_orig.cpu()
+    model_orig = ClipVisionModel(model=model_orig.visual, args=args, normalize=normalize)
+    if num_gpus > 1:
+        model_orig = torch.nn.DataParallel(model_orig)
+    model_orig.cuda()
+
+    model = ClipVisionModel(model=model.visual, args=args, normalize=normalize)
+    if num_gpus > 1:
+        model = torch.nn.DataParallel(model)
+    model.cuda()
+
+    ####################################################### get slot-attention model #########################################################
+    model_slots = model_slots
+    if num_gpus > 1:
+        model_slots = torch.nn.DataParallel(model_slots)
+        proj_head = torch.nn.DataParallel(proj_head)
+    model_slots.cuda()
+    proj_head.cuda()
+
+    # set optimizer (all params have requires_grad=True)
+    params = unwrap_model(model).model.parameters()
+    params_head = unwrap_model(proj_head).parameters()
+
+    if args.opt == 'adamw':
+        optimizer = torch.optim.AdamW(
+            [{'params': params},
+             {'params': params_head}], lr=args.lr, weight_decay=args.wd)
+    elif args.opt == 'sgd':
+        optimizer = torch.optim.SGD(
+            params,
+            lr=args.lr,
+            momentum=args.momentum_sgd,
+            weight_decay=args.wd
+        )
+    else:
+        raise ValueError(f'Optimizer {args.opt} not supported.')
+    if args.optimizer_state != '':
+        optimizer.load_state_dict(torch.load(args.optimizer_state))
+
+    # set scheduler
+    scheduler = cosine_lr(optimizer, args.lr, args.warmup, args.steps)
+
+    # compute amount of epochs
+    total_epochs = args.steps / len(dataloader)
+    print(f'train for {total_epochs} epochs')
+    args.total_epochs = total_epochs
+
+    # finetune
+    step_total = args.start_step
+    epoch = 0
+    while step_total < args.steps:
+        step_total = train_one_epoch(
+            step_total,
+            model=model,
+            model_orig=model_orig,
+            model_slots=model_slots,
+            proj_head=proj_head,
+            dataloader=dataloader,
+            dataloader_eval=dataloader_eval,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            embedding_text_labels_norm=embedding_text_labels_norm,
+            normalize=normalize,
+            args=args,
+            epoch=epoch
+        )
+        print(f'Epoch {epoch} done.')
+        epoch += 1
+
+    # save final model
+    torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/final.pt')
+    torch.save(unwrap_model(proj_head).state_dict(), f'{args.output_dir}/checkpoints/final_proj_head.pt')
+
+    torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/final_opt.pt')
+
+    if args.output_dir.endswith('_temp'):
+        # rename temp dir to final dir
+        os.rename(args.output_dir, args.output_dir[:-5])
+
+class ClipVisionModel(torch.nn.Module):
+    def __init__(self, model, args, normalize):
+        super().__init__()
+        self.model = model
+        self.args = args
+        self.normalize = normalize
+
+    def forward(self, vision, output_normalize, return_all_blocks=True, need_OT=False, object_token=None):
+        vision = self.normalize(vision)
+        embedding, patches = self.model(vision, return_all_blocks=return_all_blocks, need_OT=need_OT, object_token=object_token)
+        if output_normalize:
+            embedding = F.normalize(embedding, dim=-1)
+        return embedding, patches
+
+
+
+
+
+class ComputeLossWrapper:
+    def __init__(self, embedding_orig, embedding_text_labels_norm, reduction='mean', loss=None,
+                 logit_scale=100.):
+        self.embedding_orig = embedding_orig
+        self.embedding_text_labels_norm = embedding_text_labels_norm
+        self.reduction = reduction
+        self.loss_str = loss
+        self.logit_scale = logit_scale
+
+    def __call__(self, embedding, targets):
+        return compute_loss(
+            loss_str=self.loss_str, embedding=embedding, targets=targets,
+            embedding_orig=self.embedding_orig, logit_scale=self.logit_scale,
+            embedding_text_labels_norm=self.embedding_text_labels_norm, reduction=self.reduction
+        )
+
+def train_one_epoch(
+        step_total, model, model_orig, model_slots, proj_head, dataloader, optimizer, scheduler, normalize,
+        embedding_text_labels_norm, args, epoch, dataloader_eval=None
+):
+    model_orig.eval()
+    model.train()
+    model_slots.eval()
+    proj_head.train()
+
+    loss_meter = AverageMeter('loss')
+    cos_sim_meter = AverageMeter('cos-sim')
+    acc_meter = AverageMeter('acc')
+    racc_meter = AverageMeter('racc')
+
+    epoch_start_time = time.time()
+    for i, (data, targets) in enumerate(dataloader):
+        is_classification = isinstance(targets, torch.Tensor)
+        data = data.cuda()
+        n_samples = data.shape[0]
+        if is_classification:
+            targets = targets.cuda()
+
+        with torch.no_grad():
+            embedding_orig, patches_orig = model_orig(vision=data, output_normalize=args.output_normalize)
+            reconstruction, slots, masks, x_dinov2 = model_slots(patches_orig)  # (B, token, 768)
+
+        object_token = proj_head(slots)
+
+        # loss for the attack
+        loss_inner_wrapper = ComputeLossWrapper(
+            embedding_orig, embedding_text_labels_norm,
+            reduction='none' if args.attack == 'apgd' else 'mean', loss=args.inner_loss,
+            logit_scale=100.
+        )
+        model.eval()
+
+        if args.attack == 'pgd':
+            data_adv = pgd(
+                forward=model,
+                loss_fn=loss_inner_wrapper,
+                data_clean=data,
+                targets=targets,
+                norm=args.norm,
+                eps=args.eps,
+                iterations=args.iterations_adv,
+                stepsize=args.stepsize_adv,
+                output_normalize=args.output_normalize,
+                perturbation=torch.zeros_like(data).uniform_(-args.eps, args.eps).requires_grad_(True),
+                mode='max',
+                verbose=False,
+                need_OT=False
+            )
+        elif args.attack == 'apgd':
+            # apgd currently always applies output normalization
+            data_adv = apgd(
+                model=model,
+                loss_fn=loss_inner_wrapper,
+                x=data,
+                y=targets,
+                norm=args.norm,
+                eps=args.eps,
+                n_iter=args.iterations_adv,
+                verbose=True
+            )
+        elif args.attack == 'none':
+            data_adv = data
+
+        del loss_inner_wrapper
+        model.train()
+
+        embedding_clean, patches_clean = model(data, output_normalize=args.output_normalize)
+        if args.clean_weight > 0.:
+            loss_clean = compute_loss(
+                loss_str=args.loss_clean, embedding=embedding_clean, targets=targets,
+                embedding_orig=embedding_orig, logit_scale=100., embedding_text_labels_norm=None
+            )
+        else:
+            loss_clean = 0.
+
+        embedding_adv, patches_adv = model(data_adv, output_normalize=args.output_normalize, need_OT=True, object_token=object_token)
+        del data, data_adv
+
+        if args.trades:
+            embedding_clean_no_grad = embedding_clean.detach().clone()
+            embedding_orig.cpu()
+
+        loss = compute_loss(
+            loss_str=args.loss, embedding=embedding_adv, targets=targets,
+            embedding_orig=embedding_orig if not args.trades else embedding_clean_no_grad,
+            logit_scale=100., embedding_text_labels_norm=embedding_text_labels_norm
+        )
+        loss_total = args.clean_weight * loss_clean + (1 - args.clean_weight) * loss
+        loss_total.backward()
+        optimizer.step()
+        optimizer.zero_grad()
+        step_total += 1
+        scheduler(step_total)
+
+        with torch.no_grad():
+            # only for logging
+            embedding_orig.cuda()
+            cos_sim_clean = F.cosine_similarity(embedding_clean, embedding_orig, dim=1).mean()
+            cos_sim = F.cosine_similarity(embedding_adv, embedding_orig, dim=1).mean()
+            if is_classification:
+                logits_adv = embedding_adv @ embedding_text_labels_norm
+                racc = compute_acc(logits_adv, targets)
+                embedding_clean_norm = F.normalize(embedding_clean, dim=1)
+                logits_clean = embedding_clean_norm @ embedding_text_labels_norm
+                acc = compute_acc(logits_clean, targets)
+                acc_meter.update(acc, n_samples)
+                racc_meter.update(racc, n_samples)
+                del embedding_clean_norm, embedding_clean
+            else:
+                acc = None
+                racc = None
+
+        loss_meter.update(loss.item(), n_samples)
+        cos_sim_meter.update(cos_sim.item(), n_samples)
+
+        eval_logs = dict()
+        if (step_total-1) % args.eval_freq == 0:
+            # we compute acc and racc (against supervised apgd) on validation data
+            model.eval()
+            data_eval, targets_eval = next(iter(dataloader_eval))
+            data_eval, targets_eval = data_eval.cuda(), targets_eval.cuda()
+            loss_eval_wrapper = ComputeLossWrapper(
+                embedding_orig=None, embedding_text_labels_norm=embedding_text_labels_norm,
+                reduction='none', loss='ce', logit_scale=100.
+            )
+            data_eval_adv = apgd(
+                model=model,
+                loss_fn=loss_eval_wrapper,
+                x=data_eval,
+                y=targets_eval,
+                norm=args.norm,
+                eps=args.eps,
+                n_iter=50,
+                initial_stepsize=0.05 * args.eps if args.clean_weight > 0 else None,
+                verbose=False
+            )
+            with torch.no_grad():
+                embedding_adv_eval_norm, patches_adv_eval_norm = model(data_eval_adv, output_normalize=True)  # we set output_normalize to True
+                logits_eval_adv = embedding_adv_eval_norm @ embedding_text_labels_norm
+                racc_eval = compute_acc(logits_eval_adv, targets_eval)
+                embedding_eval_norm, patches_adv_eval_norm = model(data_eval, output_normalize=True)
+                logits_eval = embedding_eval_norm @ embedding_text_labels_norm
+                acc_eval = compute_acc(logits_eval, targets_eval)
+                # note we compute the cosine sim between clean and adv embedding,
+                # not between orig and adv embedding as for training
+                cos_sim_eval = F.cosine_similarity(embedding_adv_eval_norm, embedding_eval_norm, dim=1).mean()
+            eval_logs['eval/racc'] = racc_eval
+            eval_logs['eval/acc'] = acc_eval
+            eval_logs['eval/cos-sim'] = cos_sim_eval
+            print(f'[eval-acc] {acc_eval:.2f} [eval-racc] {racc_eval:.2f} [eval-cos-sim] {cos_sim_eval:.3f}')
+            model.train()
+            del data_eval_adv, data_eval, targets_eval, embedding_adv_eval_norm, logits_eval_adv, embedding_eval_norm, logits_eval
+
+        lr_ = optimizer.param_groups[0].get('lr')
+        if (step_total-1) % args.log_freq == 0:
+            log_str = f'[step] {step_total} [lr] {lr_:.6f} [loss] {loss.item():.6f} [cos-sim] {cos_sim.item():.3f}'
+            if is_classification:
+                log_str += f' [acc] {acc:.2f} [racc] {racc:.2f}'
+            print(log_str)
+            log_data = {
+                'step': step_total,
+                'lr': lr_,
+                'loss': loss.item(),
+                'loss-total': loss_total.item(),
+                'cos-sim-clean': cos_sim_clean.item(),
+                'cos-sim': cos_sim.item(),
+                'acc': acc,
+                'racc': racc,
+                'avg/loss': loss_meter.avg,
+                'avg/cos-sim': cos_sim_meter.avg,
+                'avg/acc': acc_meter.avg,
+                'avg/racc': racc_meter.avg,
+            }
+            log_data.update(eval_logs)
+            if (step_total-1) % (args.log_freq * 10) == 0:
+                # compute expected average epoch time in hours
+                batch_average_time = (time.time() - epoch_start_time) / (i + 1) / (60**2)
+                epoch_average_time = batch_average_time * len(dataloader)
+                this_epoch_remaining = epoch_average_time - \
+                                       (time.time() - epoch_start_time) / 60**2
+                total_remaining = epoch_average_time * (args.total_epochs - epoch - i / len(dataloader))
+                print(f'[epoch average time] {epoch_average_time:.2f} [this epoch remaining] '
+                      f'{this_epoch_remaining:.2f} [total remaining] {total_remaining:.2f}')
+
+                log_data.update({
+                    'time/total-remaining': total_remaining,
+                    'time/this-epoch-remaining': this_epoch_remaining,
+                    'time/epoch-average-time': epoch_average_time,
+                    'time/batch-average-time': batch_average_time,
+                    'other/epoch': epoch + i / len(dataloader),
+                })
+            wandb.log(log_data)
+
+        # save 10 models over the course of training
+        if args.save_checkpoints and (step_total % (args.steps // 10) == 0):
+            # save model and optimizer state_dict
+            torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}.pt')
+            torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}_opt.pt')
+        # every 200 steps, save a fallback model, which gets overwritten
+        if step_total % 200 == 0:
+            torch.save(unwrap_model(model).model.state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}.pt')
+            torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}_opt.pt')
+            # remove old fallback models
+            for file in os.listdir(f'{args.output_dir}/checkpoints'):
+                if file.startswith('fallback') and not str(step_total) in file:
+                    os.remove(f'{args.output_dir}/checkpoints/{file}')
+
+        if step_total >= args.steps:
+            break
+
+        torch.cuda.empty_cache()
+    return step_total
+
+
+@torch.no_grad()
+def compute_acc(logits, targets):
+    preds_clean = logits.max(dim=1)[1].detach()
+    acc = (preds_clean.eq(targets).sum() / targets.shape[0]).item() * 100
+    return acc
+
+
+def compute_loss(loss_str, embedding, targets, embedding_orig, logit_scale,
+                 embedding_text_labels_norm=None, reduction='mean'):
+    if loss_str == 'l2':
+        loss = l2(out=embedding, targets=embedding_orig, reduction=reduction)
+    elif loss_str == 'ce':
+        loss = ce(
+            out=embedding @ (logit_scale * embedding_text_labels_norm),
+            targets=targets,
+            reduction=reduction
+        )
+    else:
+        raise ValueError(f'loss {loss_str} not supported')
+    return loss
+
+def l2(out, targets, reduction='none'):
+    # squared l2 - it does not divide by the latent dimension
+    # should have shape (batch_size, embedding_size)
+    assert out.shape == targets.shape, f'{out.shape} != {targets.shape}'
+    assert out.shape[0] > 1
+    # Compute the element-wise squared error
+    squared_error_batch = F.mse_loss(out, targets, reduction='none')
+    if reduction == 'mean':
+        squared_error_batch = torch.mean(squared_error_batch.sum(dim=1))
+    else:
+        squared_error_batch = squared_error_batch.sum(dim=1)
+        assert squared_error_batch.shape == (out.shape[0],), f'{squared_error_batch.shape} != {(out.shape[0],)}'
+    return squared_error_batch
+
+def ce(out, targets, reduction='mean'):
+    # out = logits
+    assert out.shape[0] == targets.shape[0], (out.shape, targets.shape)
+    assert out.shape[0] > 1
+
+    return F.cross_entropy(out, targets, reduction=reduction)
+
+if __name__ == '__main__':
+    # set seeds
+    torch.manual_seed(0)
+    np.random.seed(0)
+
+    # Parse command-line arguments
+    args = parser.parse_args()
+    args.eps /= 255
+    args.stepsize_adv /= 255
+    # make sure there is no string in args that should be a bool
+    assert not any([isinstance(x, str) and x in ['True', 'False'] for x in args.__dict__.values()]), f'args contains a string that should be a bool: {args}'
+    assert args.eval_freq % args.log_freq == 0, 'eval_freq must be a multiple of log_freq'
+
+    if args.devices != '':
+        # set cuda visible devices
+        os.environ['CUDA_VISIBLE_DEVICES'] = args.devices
+
+    num_gpus = torch.cuda.device_count()
+    if num_gpus > 1:
+        print(f'Number of GPUs available: {num_gpus}')
+    else:
+        print('No multiple GPUs available.')
+
+    # set model name and output dir
+    random_str = ''.join(random.choices(string.ascii_letters + string.digits, k=5))
+    args.finetuned_model_name = f'{args.clip_model_name}_{args.pretrained}_{args.dataset}_{args.loss}_{args.dataset}_{args.experiment_name}_{random_str}'
+    args.finetuned_model_name = args.finetuned_model_name.replace('/', '_')
+    args.output_dir = os.path.join(args.output_dir, args.finetuned_model_name)
+    # run
+    main(args)
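Note: `compute_loss` with `loss_str='l2'` above reduces to a squared L2 distance summed over the embedding dimension (it deliberately does not divide by the latent size). A small self-contained sketch of that reduction, with random stand-in embeddings:

```python
# Sketch: the squared-L2 embedding loss used by compute_loss(loss_str='l2') above.
import torch
import torch.nn.functional as F

adv = torch.randn(8, 768)    # embedding of adversarial images (stand-in)
orig = torch.randn(8, 768)   # frozen original-CLIP embedding (stand-in)

per_sample = F.mse_loss(adv, orig, reduction='none').sum(dim=1)   # shape (8,), reduction='none' path
loss_mean = per_sample.mean()                                     # reduction='mean' path
assert torch.allclose(per_sample, ((adv - orig) ** 2).sum(dim=1))
```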
train/apgd_train.py
CHANGED
@@ -178,7 +178,7 @@ def apgd_train(model, x, y, norm, eps, n_iter=10, use_rs=False, loss_fn=None,
     # grad = torch.zeros_like(x)
     # for _ in range(self.eot_iter)
     # with torch.enable_grad()
-    logits = model(x_adv, output_normalize=True)
+    logits, patch = model(x_adv, output_normalize=True)
     loss_indiv = loss_fn(logits, y)
     loss = loss_indiv.sum()
     # grad += torch.autograd.grad(loss, [x_adv])[0].detach()
@@ -285,7 +285,7 @@ def apgd_train(model, x, y, norm, eps, n_iter=10, use_rs=False, loss_fn=None,
     # grad = torch.zeros_like(x)
     # for _ in range(self.eot_iter)
     # with torch.enable_grad()
-    logits = model(x_adv, output_normalize=True)
+    logits, patch = model(x_adv, output_normalize=True)
     loss_indiv = loss_fn(logits, y)
     loss = loss_indiv.sum()
 
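Note: the only change in both hunks is the unpacking; the attacked model is now expected to return an `(embedding, patches)` tuple, and `apgd_train` keeps using just the first element. A toy sketch of that calling convention with a dummy model (not the repo's `ClipVisionModel`):

```python
# Sketch: why the unpacking changed - the wrapped vision model now returns a tuple.
import torch

class DummyVisionModel(torch.nn.Module):
    def forward(self, x, output_normalize=True):
        emb = x.flatten(1).mean(dim=1, keepdim=True)   # stand-in embedding
        patches = x                                    # stand-in patch tokens
        if output_normalize:
            emb = torch.nn.functional.normalize(emb, dim=-1)
        return emb, patches

model = DummyVisionModel()
x_adv = torch.rand(4, 3, 8, 8)
logits, patch = model(x_adv, output_normalize=True)  # matches the patched apgd_train call
```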
train/pgd_train.py
CHANGED
@@ -15,7 +15,8 @@ def pgd(
         perturbation=None,
         mode='min',
         momentum=0.9,
-        verbose=False
+        verbose=False,
+        need_OT=False
 ):
     """
     Minimize or maximize given loss
@@ -29,7 +30,7 @@ def pgd(
     for i in range(iterations):
         perturbation.requires_grad = True
         with torch.enable_grad():
-            out = forward(data_clean + perturbation, output_normalize=output_normalize)
+            out, patch_out = forward(data_clean + perturbation, output_normalize=output_normalize, need_OT=need_OT)
             loss = loss_fn(out, targets)
             if verbose:
                 print(f'[{i}] {loss.item():.5f}')
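Note: `pgd` gains a `need_OT` flag that is simply forwarded to the model, and the inner forward now unpacks `(out, patch_out)`. A simplified, self-contained PGD-style loop illustrating that convention (a stand-in for the repo's `pgd`, not its actual implementation):

```python
# Sketch: a minimal PGD loop using the (out, patch_out) convention and a need_OT flag
# as in the patched train/pgd_train.py; toy forward and shapes, not the repo's code.
import torch

def toy_forward(x, output_normalize=False, need_OT=False):
    emb = x.flatten(1)          # stand-in embedding
    patches = x                 # stand-in patch tokens
    if output_normalize:
        emb = torch.nn.functional.normalize(emb, dim=-1)
    return emb, patches

data_clean = torch.rand(2, 3, 8, 8)
target_emb = torch.randn(2, 3 * 8 * 8)
eps, stepsize = 4 / 255, 1 / 255
perturbation = torch.zeros_like(data_clean).uniform_(-eps, eps).requires_grad_(True)

for _ in range(10):
    out, patch_out = toy_forward(data_clean + perturbation, output_normalize=False, need_OT=False)
    loss = ((out - target_emb) ** 2).sum(dim=1).mean()
    grad = torch.autograd.grad(loss, [perturbation])[0]
    with torch.no_grad():
        perturbation += stepsize * grad.sign()   # maximize the embedding loss (mode='max')
        perturbation.clamp_(-eps, eps)
        perturbation.copy_((data_clean + perturbation).clamp(0, 1) - data_clean)
data_adv = (data_clean + perturbation).detach()
```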
train/training_clip+dinov2_slots.py
ADDED
@@ -0,0 +1,543 @@
+import sys
+
+
+from train.datasets import COCOFlickrDataset, ImageNetDataset
+from CLIP_eval.eval_utils import load_clip_model
+
+sys.path.append("open_flamingo")
+import os
+import shutil
+import time
+import string
+import random
+
+import numpy as np
+import open_clip
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from training.scheduler import cosine_lr
+from torchvision import transforms
+from open_flamingo.eval.classification_utils import IMAGENET_1K_CLASS_ID_TO_LABEL
+from train.pgd_train import pgd
+from train.apgd_train import apgd_train as apgd
+import wandb
+from train.utils import init_wandb, AverageMeter
+from train.sam_data import SamData
+from open_flamingo.eval.models.utils import unwrap_model
+from train.utils import str2bool
+from torch.hub import load
+
+
+
+RANDOM_SEED = 42  # any random number
+def set_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)  # CPU
+    torch.cuda.manual_seed(seed)  # current GPU
+    torch.cuda.manual_seed_all(seed)  # all GPUs
+    os.environ['PYTHONHASHSEED'] = str(seed)  # disable hash randomization
+    torch.backends.cudnn.deterministic = True  # make cuDNN pick deterministic convolution algorithms
+    torch.backends.cudnn.benchmark = False  # True auto-tunes the fastest algorithm for the current config; False keeps runs reproducible
+set_seed(RANDOM_SEED)
+
+from slots.DINOSAUR import DINOSAURpp
+import matplotlib.pyplot as plt
+from einops import rearrange, repeat
+
+
+
+import argparse
+
+parser = argparse.ArgumentParser()
+parser.add_argument('--clip_model_name', type=str, default='ViT-L-14', help='ViT-L-14, ViT-B-32')
+parser.add_argument('--pretrained', type=str, default='openai')
+parser.add_argument('--dataset', type=str, default='imagenet')
+parser.add_argument('--template', type=str, default='std')
+parser.add_argument('--imagenet_root', type=str, default='/mnt/datasets/imagenet', help='Imagenet dataset root directory')
+parser.add_argument('--output_normalize', type=str2bool, default=False, help='Whether the embedding is normalized')
+parser.add_argument('--start_step', type=int, default=0, help='Start step for training')
+parser.add_argument('--optimizer_state', type=str, default='', help='Optimizer state file path')
+parser.add_argument('--steps', type=int, default=20000, help='Number of training steps')
+parser.add_argument('--warmup', type=int, default=14000, help='Warmup steps')
+parser.add_argument('--batch_size', type=int, default=256)
+parser.add_argument('--loss', type=str, default='l2', help='ce, l2')
+parser.add_argument('--loss_clean', type=str, default='none', help='ce, l2')
+parser.add_argument('--clean_weight', type=float, default=0., help='Weight for clean loss')
+parser.add_argument('--trades', type=str2bool, default=False, help='Use TRADES')
+parser.add_argument('--opt', type=str, default='adamw', help='Optimizer type; sgd, adamw')
+parser.add_argument('--momentum_sgd', type=float, default=0.9, help='Momentum for SGD optimizer')
+parser.add_argument('--lr', type=float, default=1e-5, help='Learning rate')
+parser.add_argument('--wd', type=float, default=1e-4, help='Weight decay')
+parser.add_argument('--attack', type=str, default='apgd', help='Adversarial attack type')
+parser.add_argument('--inner_loss', type=str, default='l2', help='Inner loss function for adversarial training')
+parser.add_argument('--norm', type=str, default='linf', help='Norm for adversarial perturbation')
+parser.add_argument('--eps', type=float, default=4, help='Epsilon for adversarial perturbation')
+parser.add_argument('--iterations_adv', type=int, default=10, help='Iterations for adversarial attack')
+parser.add_argument('--stepsize_adv', type=float, default=1., help='Step size for adversarial attack (no effect for apgd)')
+parser.add_argument('--wandb', type=str2bool, default=True, help='Use Weights & Biases for logging')
+parser.add_argument('--experiment_name', type=str, default='')
+parser.add_argument('--overwrite', type=str2bool, default=False, help='Overwrite existing directory')
+parser.add_argument('--log_freq', type=int, default=1, help='Logging frequency')
+parser.add_argument('--eval_freq', type=int, default=50, help='Evaluation frequency')
+parser.add_argument('--output_dir', type=str, default='', help='Output directory')
+parser.add_argument('--save_checkpoints', type=str2bool, default=True, help='Save 10 training checkpoints')
+parser.add_argument('--devices', type=str, default='', help='Device IDs for CUDA')
+
+
+def main(args):
+    # setup wandb
+    if args.wandb:
+        init_wandb(
+            project_name='clip-finetune',
+            model_name=args.finetuned_model_name,
+            config=vars(args)
+        )
+    else:
+        wandb.init(mode='disabled')
+
+    # print args
+    print(f"Arguments:\n{'-' * 20}")
+    for arg, value in vars(args).items():
+        print(f"{arg}: {value}")
+    print(f"{'-' * 20}")
+
+    # setup dirs
+    if args.overwrite:
+        shutil.rmtree(args.output_dir, ignore_errors=True)
+    os.makedirs(os.path.join(args.output_dir, 'checkpoints'), exist_ok=False)
+
+    # write args to file
+    with open(os.path.join(args.output_dir, 'args.txt'), 'w') as f:
+        f.write(str(args))
+
+    main_device = 0
+    # get models
+    from open_clip.model import CLIPVisionCfg
+    CLIPVisionCfg.output_tokens = True
+    model_orig, _, image_processor = open_clip.create_model_and_transforms(
+        args.clip_model_name, pretrained='openai'  # , output_tokens=True: optionally return the class token plus the patch tokens
+    )
+    # Remove the Normalize transform by creating a new Compose object
+    preprocessor_without_normalize = transforms.Compose(image_processor.transforms[:-1])
+    normalize = image_processor.transforms[-1]
+    del image_processor
+    print(f'[preprocessor_without_normalize] {preprocessor_without_normalize}')
+
+    ####################################################### get slot-attention model #########################################################
+    cfg_dict = {'slot_dim': 256, 'num_slots': 10, 'token_num': 256, 'ISA': False, 'slot_att_iter': 3, 'query_opt': False}
+    model_slots = DINOSAURpp(cfg_dict)
+
+    # get data
+    if args.dataset == 'imagenet':
+        dataset = ImageNetDataset(
+            root=args.imagenet_root + '/train',
+            transform=preprocessor_without_normalize,
+        )
+
+    elif args.dataset == 'segment_anything':
+        dataset = SamData('/data/naman_deep_singh/datasets/newSAM', transform=preprocessor_without_normalize)
+
+        print(dataset.__len__())
+    elif args.dataset == 'coco':
+        if os.path.exists('/mnt/datasets/coco'):
+            image_dir_path = '/mnt/datasets/coco/train2017'
+            annotations_path = '/mnt/datasets/coco/annotations/captions_train2017.json'
+        elif os.path.exists('/mnt/lustre'):
+            image_dir_path = '/mnt/lustre/hein/cschlarmann37/datasets/coco/train2017'
+            annotations_path = '/mnt/lustre/hein/cschlarmann37/datasets/coco/annotations/captions_train2017.json'
+        else:
+            raise ValueError('COCO dataset not found')
+        dataset = COCOFlickrDataset(
+            image_dir_path=image_dir_path,
+            annotations_path=annotations_path,
+            transform=preprocessor_without_normalize
+        )
+    dataset_eval = ImageNetDataset(
+        root=args.imagenet_root + '/val',
+        transform=preprocessor_without_normalize,
+    )
+    dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, num_workers=8, drop_last=True)
+    dataloader_eval = DataLoader(dataset_eval, batch_size=args.batch_size, shuffle=True, num_workers=8, drop_last=True)
+
+    # Get text label embeddings of all ImageNet classes
+    if args.template == 'std':
+        template = 'This is a photo of a {}'
+    elif args.template == 'blurry':
+        template = 'This is a blurry photo of a {}'
+    else:
+        raise ValueError(f'Unknown template: {args.template}')
+    print(f'template: {template}')
+    texts = [template.format(c) for c in IMAGENET_1K_CLASS_ID_TO_LABEL.values()]
+    text_tokens = open_clip.tokenize(texts)
+    model_orig.to(main_device)
+    with torch.no_grad():
+        embedding_text_labels_norm = []
+        for el in (text_tokens[:500], text_tokens[500:]):
+            # we need to split the text tokens into two batches because otherwise we run out of memory
+            # note that we are accessing the model directly here, not the CustomModel wrapper
+            # thus it is always normalizing the text embeddings
+            embedding_text_labels_norm.append(
+                model_orig.encode_text(el.to(main_device), normalize=True).detach().cpu()
+            )
+        embedding_text_labels_norm = torch.cat(embedding_text_labels_norm).T.to(main_device)
+    assert torch.allclose(
+        F.normalize(embedding_text_labels_norm, dim=0),
+        embedding_text_labels_norm
+    )
+    if args.clip_model_name == 'ViT-B-32':
+        assert embedding_text_labels_norm.shape == (512, 1000), embedding_text_labels_norm.shape
+    elif args.clip_model_name in ('ViT-L-14', 'ViT-L-14-336'):
+        assert embedding_text_labels_norm.shape == (768, 1000), embedding_text_labels_norm.shape
+    else:
+        raise ValueError(f'Unknown model: {args.clip_model_name}')
+
+    model_orig.cpu()
+    model_orig = ClipVisionModel(model=model_orig.visual, args=args, normalize=normalize)
+    if num_gpus > 1:
+        model_orig = torch.nn.DataParallel(model_orig)
+    model_orig.cuda()
+
+    model_slots = model_slots
+    if num_gpus > 1:
+        model_slots = torch.nn.DataParallel(model_slots)
+    model_slots.cuda()
+
+
+
+    ####################################################### add DINOv2 as target model #####################################################
+    target_backbone_name = 'dinov2_vitb14'
+    model_dinov2 = load('/home/tly/.cache/torch/hub/facebookresearch_dinov2_main', target_backbone_name, source='local')  # /home/ly/.cache/torch/hub/facebookresearch_dinov2_main facebookresearch/dinov2
+    # model_dinov2 = load('facebookresearch/dinov2', target_backbone_name)  # /home/ly/.cache/torch/hub/facebookresearch_dinov2_main facebookresearch/dinov2
+
+    proj_head = torch.nn.Linear(1024, 768)
+    if num_gpus > 1:
+        model_dinov2 = torch.nn.DataParallel(model_dinov2)
+        proj_head = torch.nn.DataParallel(proj_head)
+    model_dinov2.cuda()
+    proj_head.cuda()
+
+    # set optimizer (all params have requires_grad=True)
+    params_slots = unwrap_model(model_slots).parameters()
+    params_head = unwrap_model(proj_head).parameters()
+
+    if args.opt == 'adamw':
+        optimizer = torch.optim.AdamW(
+            [{'params': params_slots},
+             {'params': params_head}], lr=args.lr, weight_decay=args.wd)
+    elif args.opt == 'sgd':
+        optimizer = torch.optim.SGD(
+            [{'params': params_slots},
+             {'params': params_head}],
+            lr=args.lr,
+            momentum=args.momentum_sgd,
+            weight_decay=args.wd
+        )
+    else:
+        raise ValueError(f'Optimizer {args.opt} not supported.')
+    if args.optimizer_state != '':
+        optimizer.load_state_dict(torch.load(args.optimizer_state))
+
+    # set scheduler
+    scheduler = cosine_lr(optimizer, args.lr, args.warmup, args.steps)
+
+    # compute amount of epochs
+    total_epochs = args.steps / len(dataloader)
+    print(f'train for {total_epochs} epochs')
+    args.total_epochs = total_epochs
+
+    # finetune
+    step_total = args.start_step
+    epoch = 0
+    while step_total < args.steps:
+        step_total = train_one_epoch_slots(
+            step_total,
+            model_slots=model_slots,
+            model_orig=model_orig,
+            model_dinov2=model_dinov2,
+            proj_head=proj_head,
+            dataloader=dataloader,
+            dataloader_eval=dataloader_eval,
+            optimizer=optimizer,
+            scheduler=scheduler,
+            embedding_text_labels_norm=embedding_text_labels_norm,
+            normalize=normalize,
+            args=args,
+            epoch=epoch
+        )
+        print(f'Epoch {epoch} done.')
+        epoch += 1
+
+    # save final model
+    torch.save(unwrap_model(model_slots).state_dict(), f'{args.output_dir}/checkpoints/final.pt')
+    torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/final_opt.pt')
+
+    if args.output_dir.endswith('_temp'):
+        # rename temp dir to final dir
+        os.rename(args.output_dir, args.output_dir[:-5])
+
+class ClipVisionModel(torch.nn.Module):
+    def __init__(self, model, args, normalize):
+        super().__init__()
+        self.model = model
+        self.args = args
+        self.normalize = normalize
+
+    def forward(self, vision, output_normalize, return_all_blocks=False):
+        vision = self.normalize(vision)
+        embedding, patches = self.model(vision, return_all_blocks=return_all_blocks)
+        if output_normalize:
+            embedding = F.normalize(embedding, dim=-1)
+        return embedding, patches
+
+
+class ComputeLossWrapper:
+    def __init__(self, embedding_orig, embedding_text_labels_norm, reduction='mean', loss=None,
+                 logit_scale=100.):
+        self.embedding_orig = embedding_orig
+        self.embedding_text_labels_norm = embedding_text_labels_norm
+        self.reduction = reduction
+        self.loss_str = loss
+        self.logit_scale = logit_scale
+
+    def __call__(self, embedding, targets):
+        return compute_loss(
+            loss_str=self.loss_str, embedding=embedding, targets=targets,
+            embedding_orig=self.embedding_orig, logit_scale=self.logit_scale,
+            embedding_text_labels_norm=self.embedding_text_labels_norm, reduction=self.reduction
+        )
+
+def mean_flat(x):
+    """
+    Take the mean over all non-batch dimensions.
+    """
+    return torch.mean(x, dim=list(range(1, len(x.size()))))
+
+
+def train_one_epoch_slots(
+        step_total, model_slots, model_orig, model_dinov2, proj_head, dataloader, optimizer, scheduler, normalize,
+        embedding_text_labels_norm, args, epoch, dataloader_eval=None
+):
+    model_orig.eval()
+    model_slots.train()
+
+    MSEFunc = torch.nn.MSELoss()
+
+
+    loss_meter = AverageMeter('loss')
+
+
+    epoch_start_time = time.time()
+    for i, (data, targets) in enumerate(dataloader):
+        is_classification = isinstance(targets, torch.Tensor)
+        data = data.cuda()
+        n_samples = data.shape[0]
+        if is_classification:
+            targets = targets.cuda()
+
+        with torch.no_grad():
+            embedding_orig, patches_orig = model_orig(vision=data, output_normalize=args.output_normalize, return_all_blocks=True)
+            feat_dinov2_dict = model_dinov2(data, is_training=True)
+
+        if num_gpus > 1:
+            patches_orig = model_orig.module.model.ln_pre(patches_orig)
+        else:
+            patches_orig = model_orig.model.ln_pre(patches_orig)
+
+
+        reconstruction, slots, masks, x_dinov2 = model_slots(patches_orig)  # (B, token, 768)
+        x_dinov2 = torch.sum(x_dinov2, dim=1)
+
+        # change target to dino features
+        # reconstruction_to_dinov2 = proj_head(reconstruction)
+
+        b, hw, c = patches_orig.shape
+        h, w = int(np.sqrt(hw)), int(np.sqrt(hw))
+        k = slots.size(1)
+        c_dinov2 = feat_dinov2_dict['x_norm_patchtokens'].size(-1)
+
+        reconstruction = rearrange(reconstruction, 'b (h w) c -> b c h w', h=h, w=w)
+        masks = rearrange(masks, 'b k (h w) -> b k h w', h=h, w=w, k=k)
+        patches_orig = rearrange(patches_orig, 'b (h w) c -> b c h w', h=h, w=w)
+
+
+        # loss for the attack
+        loss = MSEFunc(reconstruction, patches_orig)
+        x_dinov2 = torch.nn.functional.normalize(x_dinov2, dim=-1)
+        feat_dinov2 = torch.nn.functional.normalize(feat_dinov2_dict['x_norm_patchtokens'], dim=-1)
+        # x_dinov2 = torch.nn.functional.normalize(x_dinov2, dim=-1)
+        # feat_dinov2 = torch.nn.functional.normalize(feat_dinov2_dict['x_norm_patchtokens'], dim=-1)
+        # loss_dinov2 = MSEFunc(feat_dinov2, x_dinov2)  # reconstruction_to_dinov2)
+        loss_dinov2 = mean_flat(-(x_dinov2 * feat_dinov2).sum(dim=-1)).mean()
+
+        loss_total = loss + 0.0001*loss_dinov2
+
+
+        loss_total.backward()
+        optimizer.step()
+        optimizer.zero_grad()
+        step_total += 1
+        scheduler(step_total)
+
+        lr_ = optimizer.param_groups[0].get('lr')
+        if (step_total-1) % args.log_freq == 0:
+            log_str = f'[step] {step_total} [lr] {lr_:.6f} [loss] {loss.item():.6f} [loss_dinov2] {loss_dinov2.item():.6f}'
+            print(log_str)
+            log_data = {
+                'step': step_total,
+                'lr': lr_,
+                'loss': loss.item(),
+                'loss_dinov2': loss_dinov2.item(),
+                'loss-total': loss_total.item(),
+                'avg/loss': loss_meter.avg,
+            }
+            if (step_total-1) % (args.log_freq * 10) == 0:
+                # compute expected average epoch time in hours
+                batch_average_time = (time.time() - epoch_start_time) / (i + 1) / (60**2)
+                epoch_average_time = batch_average_time * len(dataloader)
+                this_epoch_remaining = epoch_average_time - \
+                                       (time.time() - epoch_start_time) / 60**2
+                total_remaining = epoch_average_time * (args.total_epochs - epoch - i / len(dataloader))
+                print(f'[epoch average time] {epoch_average_time:.2f} [this epoch remaining] '
+                      f'{this_epoch_remaining:.2f} [total remaining] {total_remaining:.2f}')
+
+                log_data.update({
+                    'time/total-remaining': total_remaining,
+                    'time/this-epoch-remaining': this_epoch_remaining,
+                    'time/epoch-average-time': epoch_average_time,
+                    'time/batch-average-time': batch_average_time,
+                    'other/epoch': epoch + i / len(dataloader),
+                })
+            wandb.log(log_data)
+
+        # save 10 models over the course of training
+        if args.save_checkpoints and (step_total % (args.steps // 10) == 0):
+            # save model and optimizer state_dict
+            torch.save(unwrap_model(model_slots).state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}.pt')
+            torch.save(unwrap_model(proj_head).state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}_proj.pt')
+
+            torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/step_{step_total}_opt.pt')
+        # every 200 steps, save a fallback model, which gets overwritten
+        if step_total % 200 == 0:
+            torch.save(unwrap_model(model_slots).state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}.pt')
+            torch.save(unwrap_model(proj_head).state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}_proj.pt')
+
+            torch.save(optimizer.state_dict(), f'{args.output_dir}/checkpoints/fallback_{step_total}_opt.pt')
+            # remove old fallback models
+            for file in os.listdir(f'{args.output_dir}/checkpoints'):
+                if file.startswith('fallback') and not str(step_total) in file:
+                    os.remove(f'{args.output_dir}/checkpoints/{file}')
+
+        ######################################################## Save ori Image and recon Image ########################
+        # if epoch % 5 == 0:
+        save_pics_path = os.path.join(args.output_dir, 'slots_recons')
+
+        recon_pic_save_path = os.path.join(save_pics_path, args.dataset)
+        os.makedirs(recon_pic_save_path, exist_ok=True)
+
+        plt.imshow(reconstruction[0, 0].detach().cpu().numpy())
+        save_name = 'recon_pic_steps{}.png'.format(step_total)
+        plt.savefig(os.path.join(recon_pic_save_path, save_name))
+
+        plt.imshow(patches_orig[0, 0].detach().cpu().numpy())
+        save_name = 'recon_pic_steps{}_feat.png'.format(step_total)
+        plt.savefig(os.path.join(recon_pic_save_path, save_name))
+
+
+        plt.imshow(rearrange(x_dinov2, 'b (h w) c_dinov2 -> b c_dinov2 h w', h=h, w=w)[0, 0].detach().cpu().numpy())
+        save_name = 'recon_pic_steps{}_dinov2.png'.format(step_total)
+        plt.savefig(os.path.join(recon_pic_save_path, save_name))
+
+        plt.imshow(rearrange(feat_dinov2_dict['x_norm_patchtokens'], 'b (h w) c_dinov2 -> b c_dinov2 h w', h=h, w=w)[0, 0].detach().cpu().numpy())
+        save_name = 'recon_pic_steps{}_feat_dinov2.png'.format(step_total)
+        plt.savefig(os.path.join(recon_pic_save_path, save_name))
+
+        plt.imshow(data[0].permute(1, 2, 0).detach().cpu().numpy())
+        save_name = 'recon_pic_steps{}_ori.png'.format(step_total)
+        plt.savefig(os.path.join(recon_pic_save_path, save_name))
+
+        plt.close('all')
+
+        if step_total >= args.steps:
+            break
+
+        # torch.cuda.empty_cache()
+
+
+
+    return step_total
+
+
+@torch.no_grad()
+def compute_acc(logits, targets):
+    preds_clean = logits.max(dim=1)[1].detach()
+    acc = (preds_clean.eq(targets).sum() / targets.shape[0]).item() * 100
+    return acc
+
+
+def compute_loss(loss_str, embedding, targets, embedding_orig, logit_scale,
+                 embedding_text_labels_norm=None, reduction='mean'):
+    if loss_str == 'l2':
+        loss = l2(out=embedding, targets=embedding_orig, reduction=reduction)
+    elif loss_str == 'ce':
+        loss = ce(
+            out=embedding @ (logit_scale * embedding_text_labels_norm),
+            targets=targets,
+            reduction=reduction
+        )
+    else:
+        raise ValueError(f'loss {loss_str} not supported')
+    return loss
+
+def l2(out, targets, reduction='none'):
+    # squared l2 - it does not divide by the latent dimension
+    # should have shape (batch_size, embedding_size)
+    assert out.shape == targets.shape, f'{out.shape} != {targets.shape}'
+    assert out.shape[0] > 1
+    # Compute the element-wise squared error
+    squared_error_batch = F.mse_loss(out, targets, reduction='none')
+    if reduction == 'mean':
+        squared_error_batch = torch.mean(squared_error_batch.sum(dim=1))
+    else:
+        squared_error_batch = squared_error_batch.sum(dim=1)
+        assert squared_error_batch.shape == (out.shape[0],), f'{squared_error_batch.shape} != {(out.shape[0],)}'
+    return squared_error_batch
+
+def ce(out, targets, reduction='mean'):
+    # out = logits
+    assert out.shape[0] == targets.shape[0], (out.shape, targets.shape)
+    assert out.shape[0] > 1
+
+    return F.cross_entropy(out, targets, reduction=reduction)
+
+if __name__ == '__main__':
+    # set seeds
+    torch.manual_seed(0)
+    np.random.seed(0)
+
+    # Parse command-line arguments
+    args = parser.parse_args()
+    args.eps /= 255
+    args.stepsize_adv /= 255
+    # make sure there is no string in args that should be a bool
+    assert not any([isinstance(x, str) and x in ['True', 'False'] for x in args.__dict__.values()]), f'args contains a string that should be a bool: {args}'
+    assert args.eval_freq % args.log_freq == 0, 'eval_freq must be a multiple of log_freq'
+
+    if args.devices != '':
+        # set cuda visible devices
+        os.environ['CUDA_VISIBLE_DEVICES'] = args.devices
+
+    num_gpus = torch.cuda.device_count()
+    if num_gpus > 1:
+        print(f'Number of GPUs available: {num_gpus}')
+    else:
+        print('No multiple GPUs available.')
+
+    # set model name and output dir
+    random_str = ''.join(random.choices(string.ascii_letters + string.digits, k=5))
+    args.finetuned_model_name = f'{args.clip_model_name}_{args.pretrained}_{args.dataset}_{args.loss}_{args.dataset}_{args.experiment_name}_{random_str}'
+    args.finetuned_model_name = args.finetuned_model_name.replace('/', '_')
+    args.output_dir = os.path.join(args.output_dir, args.finetuned_model_name)
+    # run
+    main(args)
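Note: `loss_dinov2` above is a negative cosine similarity between L2-normalized slot-decoder features and DINOv2 patch tokens, averaged over tokens by `mean_flat`. A self-contained sketch checking that this matches `-F.cosine_similarity` on random features:

```python
# Sketch: the negative-cosine alignment loss used for loss_dinov2 above.
import torch
import torch.nn.functional as F

def mean_flat(x):
    # mean over all non-batch dimensions (same helper as in the training script)
    return torch.mean(x, dim=list(range(1, len(x.size()))))

x = F.normalize(torch.randn(2, 256, 768), dim=-1)   # slot-decoder features, normalized
y = F.normalize(torch.randn(2, 256, 768), dim=-1)   # DINOv2 patch tokens, normalized

loss_a = mean_flat(-(x * y).sum(dim=-1)).mean()
loss_b = (-F.cosine_similarity(x, y, dim=-1)).mean()
assert torch.allclose(loss_a, loss_b, atol=1e-6)
```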
train/training_clip_slots.py
CHANGED
@@ -27,6 +27,18 @@ from train.sam_data import SamData
 from open_flamingo.eval.models.utils import unwrap_model
 from train.utils import str2bool
 
+RANDOM_SEED = 42  # any random number
+def set_seed(seed):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)  # CPU
+    torch.cuda.manual_seed(seed)  # current GPU
+    torch.cuda.manual_seed_all(seed)  # all GPUs
+    os.environ['PYTHONHASHSEED'] = str(seed)  # disable hash randomization
+    torch.backends.cudnn.deterministic = True  # make cuDNN pick deterministic convolution algorithms
+    torch.backends.cudnn.benchmark = False  # True auto-tunes the fastest algorithm for the current config; False keeps runs reproducible
+set_seed(RANDOM_SEED)
+
 from slots.DINOSAUR import DINOSAURpp
 import matplotlib.pyplot as plt
 from einops import rearrange, repeat
@@ -393,7 +405,7 @@ def train_one_epoch_slots(
         if step_total >= args.steps:
             break
 
-        torch.cuda.empty_cache()
+        # torch.cuda.empty_cache()
 
 
 