import torch

from vlm_eval.attacks.utils import project_perturbation, normalize_grad


def pgd(
    forward,
    loss_fn,
    data_clean,
    targets,
    norm,
    eps,
    iterations,
    stepsize,
    output_normalize,
    perturbation=None,
    mode='min',
    momentum=0.9,
    verbose=False,
    need_OT=False
):
"""
Minimize or maximize given loss
"""
# make sure data is in image space
assert torch.max(data_clean) < 1. + 1e-6 and torch.min(data_clean) > -1e-6
    if perturbation is None:
        perturbation = torch.zeros_like(data_clean, requires_grad=True)
    velocity = torch.zeros_like(data_clean)

    for i in range(iterations):
        perturbation.requires_grad = True
        with torch.enable_grad():
            out, patch_out = forward(
                data_clean + perturbation, output_normalize=output_normalize, need_OT=need_OT
            )
            loss = loss_fn(out, targets)
            if verbose:
                print(f'[{i}] {loss.item():.5f}')
        with torch.no_grad():
            gradient = torch.autograd.grad(loss, perturbation)[0]
            if gradient.isnan().any():
                print(f'attention: nan in gradient ({gradient.isnan().sum()})')
                gradient[gradient.isnan()] = 0.
            # normalize the gradient according to the attack norm
            gradient = normalize_grad(gradient, p=norm)
            # momentum: accumulate normalized gradients, then renormalize
            velocity = momentum * velocity + gradient
            velocity = normalize_grad(velocity, p=norm)
            # update
            if mode == 'min':
                perturbation = perturbation - stepsize * velocity
            elif mode == 'max':
                perturbation = perturbation + stepsize * velocity
            else:
                raise ValueError(f'Unknown mode: {mode}')
            # project
            perturbation = project_perturbation(perturbation, eps, norm)
            perturbation = torch.clamp(
                data_clean + perturbation, 0, 1
            ) - data_clean  # clamp to image space
            assert not perturbation.isnan().any()
            assert torch.max(data_clean + perturbation) < 1. + 1e-6 and torch.min(
                data_clean + perturbation
            ) > -1e-6
            # assert (ctorch.compute_norm(perturbation, p=self.norm) <= self.eps + 1e-6).all()

    # todo return best perturbation
    # problem is that model currently does not output expanded loss
    return data_clean + perturbation.detach()
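

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the toy encoder, the MSE loss,
    # and the 'linf' norm string below are assumptions for demonstration; they
    # are not part of the surrounding evaluation pipeline.
    encoder = torch.nn.Linear(3 * 8 * 8, 16)

    def toy_forward(x, output_normalize=False, need_OT=False):
        feats = encoder(x.flatten(1))
        if output_normalize:
            feats = torch.nn.functional.normalize(feats, dim=-1)
        return feats, None  # (embedding, patch-level output unused here)

    images = torch.rand(4, 3, 8, 8)  # clean images in [0, 1]
    targets = torch.randn(4, 16)     # embeddings the attack moves away from

    adv = pgd(
        forward=toy_forward,
        loss_fn=torch.nn.functional.mse_loss,
        data_clean=images,
        targets=targets,
        norm='linf',   # assumed identifier accepted by project_perturbation/normalize_grad
        eps=4 / 255,
        iterations=10,
        stepsize=1 / 255,
        output_normalize=True,
        mode='max',    # maximize the loss to craft an adversarial example
        verbose=True,
    )
    print('max |delta|:', (adv - images).abs().max().item())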