Upload adversarial_training_clip_with_object_token.py
Browse files
train/adversarial_training_clip_with_object_token.py
CHANGED
@@ -31,6 +31,8 @@ import argparse
|
|
31 |
from slots.DINOSAUR import DINOSAURpp
|
32 |
import matplotlib.pyplot as plt
|
33 |
from einops import rearrange, repeat
|
|
|
|
|
34 |
|
35 |
parser = argparse.ArgumentParser()
|
36 |
parser.add_argument('--clip_model_name', type=str, default='ViT-L-14', help='ViT-L-14, ViT-B-32')
|
@@ -129,9 +131,42 @@ def main(args):
|
|
129 |
####################################################### get slot-attention model #########################################################
|
130 |
cfg_dict = {'slot_dim': 256, 'num_slots': 10, 'token_num': 256, 'ISA': False, 'slot_att_iter': 3, 'query_opt': False}
|
131 |
model_slots = DINOSAURpp(cfg_dict)
|
132 |
-
proj_head = torch.nn.Linear(256, 1024) # slot-num to slot-num
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
133 |
if args.optimizer_state != '':
|
134 |
proj_head.load_state_dict(torch.load(args.pretrained_proj_head))
|
|
|
|
|
135 |
|
136 |
|
137 |
|
@@ -338,7 +373,37 @@ def train_one_epoch(
|
|
338 |
embedding_orig, patches_orig = model_orig(vision=data, output_normalize=args.output_normalize)
|
339 |
reconstruction, slots, masks, x_dinov2 = model_slots(patches_orig) # (B, token, 768)
|
340 |
|
341 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
342 |
|
343 |
# loss for the attack
|
344 |
loss_inner_wrapper = ComputeLossWrapper(
|
|
|
31 |
from slots.DINOSAUR import DINOSAURpp
|
32 |
import matplotlib.pyplot as plt
|
33 |
from einops import rearrange, repeat
|
34 |
+
from IPG.IPG_arch import IPG
|
35 |
+
|
36 |
|
37 |
parser = argparse.ArgumentParser()
|
38 |
parser.add_argument('--clip_model_name', type=str, default='ViT-L-14', help='ViT-L-14, ViT-B-32')
|
|
|
131 |
####################################################### get slot-attention model #########################################################
|
132 |
cfg_dict = {'slot_dim': 256, 'num_slots': 10, 'token_num': 256, 'ISA': False, 'slot_att_iter': 3, 'query_opt': False}
|
133 |
model_slots = DINOSAURpp(cfg_dict)
|
134 |
+
# proj_head = torch.nn.Linear(256, 1024) # slot-num to slot-num
|
135 |
+
# add for IPG
|
136 |
+
upscale = 1
|
137 |
+
height = (8 // upscale)
|
138 |
+
width = (8 // upscale)
|
139 |
+
proj_head = IPG(
|
140 |
+
upscale=upscale,
|
141 |
+
in_chans=64,
|
142 |
+
out_chans=64,
|
143 |
+
img_size=(height, width),
|
144 |
+
window_size=2,
|
145 |
+
img_range=1.,
|
146 |
+
depths=[2, 2],
|
147 |
+
embed_dim=256,
|
148 |
+
num_heads=[8, 8],
|
149 |
+
mlp_ratio=4,
|
150 |
+
upsampler='sam',
|
151 |
+
resi_connection='1conv',
|
152 |
+
graph_flags=[1, 1],
|
153 |
+
stage_spec=[['GN', 'GS'], ['GN', 'GS']],
|
154 |
+
dist_type='cossim',
|
155 |
+
top_k=256,
|
156 |
+
head_wise=0,
|
157 |
+
sample_size=4,
|
158 |
+
graph_switch=1,
|
159 |
+
flex_type='interdiff_plain',
|
160 |
+
FFNtype='basic-dwconv3',
|
161 |
+
conv_scale=0,
|
162 |
+
conv_type='dwconv3-gelu-conv1-ca',
|
163 |
+
diff_scales=[1.5, 1.5],
|
164 |
+
fast_graph=1
|
165 |
+
)
|
166 |
if args.optimizer_state != '':
|
167 |
proj_head.load_state_dict(torch.load(args.pretrained_proj_head))
|
168 |
+
if args.slots_ckp != '':
|
169 |
+
model_slots.load_state_dict(torch.load(args.slots_ckp))
|
170 |
|
171 |
|
172 |
|
|
|
373 |
embedding_orig, patches_orig = model_orig(vision=data, output_normalize=args.output_normalize)
|
374 |
reconstruction, slots, masks, x_dinov2 = model_slots(patches_orig) # (B, token, 768)
|
375 |
|
376 |
+
|
377 |
+
|
378 |
+
with torch.no_grad():
|
379 |
+
b, hw, c = reconstruction.shape
|
380 |
+
h = int(pow(hw, 0.5))
|
381 |
+
w = h
|
382 |
+
k = masks.size(1)
|
383 |
+
reconstruction = rearrange(reconstruction, 'b (h w) c -> b c h w', h=h, w=w)
|
384 |
+
masks = rearrange(masks, 'b k (h w) -> b k h w', h=h, w=w)
|
385 |
+
masks_recon_feat = torch.einsum('b k h w, b c h w -> b k c', masks, reconstruction)
|
386 |
+
masks_recon_feat = masks_recon_feat.repeat(1, k, 1)
|
387 |
+
b, hw, c = masks_recon_feat.shape
|
388 |
+
h = int(pow(hw, 0.5))
|
389 |
+
w = h
|
390 |
+
sim = F.cosine_similarity(masks_recon_feat[:,None, :, :], masks_recon_feat[:,:, None, :], dim=-1).mean(-1)
|
391 |
+
sim = rearrange(sim, 'b (h w) -> b h w', h=h, w=w)
|
392 |
+
|
393 |
+
top_values, top_indices = torch.topk(sim[:, 1], k-2)
|
394 |
+
maxsim_idx = torch.argmax(sim[:, 1], dim=-1)
|
395 |
+
top_indices_slos = top_indices.unsqueeze(-1).repeat(1,1,slots.size(-1))
|
396 |
+
top_indices_sim = top_indices.unsqueeze(-1).repeat(1,1,k-2)
|
397 |
+
|
398 |
+
h, w = k-2, k-2
|
399 |
+
slots = torch.gather(slots, dim=1, index=top_indices_slos)
|
400 |
+
sim = torch.gather(sim, dim=1, index=top_indices_sim)
|
401 |
+
slot_tokens = slots.repeat(1, k-2, 1)
|
402 |
+
slot_tokens = rearrange(slot_tokens, 'b (h w) c -> b c h w', h=h, w=w)
|
403 |
+
b, c, h, w = slot_tokens.shape
|
404 |
+
object_token = proj_head(slot_tokens, sim_matric=sim)
|
405 |
+
|
406 |
+
# object_token = proj_head(slots)
|
407 |
|
408 |
# loss for the attack
|
409 |
loss_inner_wrapper = ComputeLossWrapper(
|