|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
from __future__ import unicode_literals |
|
|
|
import time |
|
|
|
import torch |
|
|
|
from autoattack.fab_projections import projection_linf, projection_l2,\ |
|
projection_l1 |
|
|
|
DEFAULT_EPS_DICT_BY_NORM = {'Linf': .3, 'L2': 1., 'L1': 5.0} |
|
|
|
|
|
class FABAttack():
    """
    Fast Adaptive Boundary Attack (Linf, L2, L1)
    https://arxiv.org/abs/1907.02044

    This is an abstract base class: subclasses must implement
    `_predict_fn`, `_get_predicted_label`, `get_diff_logits_grads_batch`
    and `get_diff_logits_grads_batch_targeted` for a concrete model.

    :param norm: Lp-norm to minimize ('Linf', 'L2', 'L1' supported)
    :param n_restarts: number of random restarts
    :param n_iter: number of iterations
    :param eps: epsilon for the random restarts (defaults per norm via
        DEFAULT_EPS_DICT_BY_NORM when None)
    :param alpha_max: upper bound on the interpolation weight alpha used to
        bias the step toward the original point
    :param eta: overshooting factor applied to each projection step
    :param beta: backward step factor applied once a point becomes adversarial
    """

    def __init__(
            self,
            norm='Linf',
            n_restarts=1,
            n_iter=100,
            eps=None,
            alpha_max=0.1,
            eta=1.05,
            beta=0.9,
            loss_fn=None,
            verbose=False,
            seed=0,
            targeted=False,
            device=None,
            n_target_classes=9):
        """ FAB-attack implementation in pytorch """

        self.norm = norm
        self.n_restarts = n_restarts
        self.n_iter = n_iter
        # Fall back to the per-norm default radius when no eps is supplied.
        self.eps = eps if eps is not None else DEFAULT_EPS_DICT_BY_NORM[norm]
        self.alpha_max = alpha_max
        self.eta = eta
        self.beta = beta
        self.targeted = targeted
        self.verbose = verbose
        self.seed = seed
        # Set per-iteration by perturb() in targeted mode (2 .. n_target_classes+1);
        # used as an index into the sorted logits. NOTE: loss_fn is accepted for
        # API compatibility but not used by this class.
        self.target_class = None
        # Lazily filled from the input tensor's device if left as None.
        self.device = device
        self.n_target_classes = n_target_classes

    def check_shape(self, x):
        # Ensure at least a 1-D tensor: a 0-d result of nonzero().squeeze()
        # would break fancy indexing below.
        return x if len(x.shape) > 0 else x.unsqueeze(0)

    def _predict_fn(self, x):
        # Subclass hook: return the model logits for a batch x.
        raise NotImplementedError("Virtual function.")

    def _get_predicted_label(self, x):
        # Subclass hook: return the argmax class label for a batch x.
        raise NotImplementedError("Virtual function.")

    def get_diff_logits_grads_batch(self, imgs, la):
        # Subclass hook: presumably returns, for each class k, the logit
        # difference f_k - f_la and its gradient w.r.t. the input
        # (df: [bs, n_classes], dg: [bs, n_classes, *input_dims]) — shapes
        # inferred from the reshapes below; confirm against subclass.
        raise NotImplementedError("Virtual function.")

    def get_diff_logits_grads_batch_targeted(self, imgs, la, la_target):
        # Subclass hook: targeted variant of the above, restricted to the
        # (label, target) logit difference.
        raise NotImplementedError("Virtual function.")

    def attack_single_run(self, x, y=None, use_rand_start=False, is_targeted=False):
        """
        Runs one restart of the FAB attack on a batch.

        :param x: clean images
        :param y: clean labels, if None we use the predicted labels
        :param is_targeted: True if we use the targeted version. The targeted
            class is assigned by `self.target_class`
        :return: tensor shaped like x with the best adversarial examples found
            (points never successfully attacked keep their clean value)
        """

        if self.device is None:
            self.device = x.device
        self.orig_dim = list(x.shape[1:])
        self.ndims = len(self.orig_dim)

        x = x.detach().clone().float().to(self.device)

        y_pred = self._get_predicted_label(x)
        if y is None:
            y = y_pred.detach().clone().long().to(self.device)
        else:
            y = y.detach().clone().long().to(self.device)
        # Only attack points the model already classifies correctly.
        pred = y_pred == y
        corr_classified = pred.float().sum()
        if self.verbose:
            print('Clean accuracy: {:.2%}'.format(pred.float().mean()))
        if pred.sum() == 0:
            # Nothing correctly classified: every point is already "adversarial".
            return x
        pred = self.check_shape(pred.nonzero().squeeze())

        if is_targeted:
            output = self._predict_fn(x)
            # target_class-th most likely class (sorted ascending, indexed
            # from the top).
            la_target = output.sort(dim=-1)[1][:, -self.target_class]
            la_target2 = la_target[pred].detach().clone()

        startt = time.time()

        # im2/la2: the correctly-classified subset actually attacked.
        im2 = x[pred].detach().clone()
        la2 = y[pred].detach().clone()
        if len(im2.shape) == self.ndims:
            # A single selected point lost its batch dim above; restore it.
            im2 = im2.unsqueeze(0)
        bs = im2.shape[0]
        u1 = torch.arange(bs)
        # adv: best adversarial found per attacked point; adv_c: full-batch copy
        # of x into which successes are written back at the end.
        adv = im2.clone()
        adv_c = x.clone()
        # res2: smallest adversarial perturbation norm found so far (1e10 = none).
        res2 = 1e10 * torch.ones([bs]).to(self.device)
        # x1: current iterate; x0: flattened original points.
        x1 = im2.clone()
        x0 = im2.clone().reshape([bs, -1])

        if use_rand_start:
            # Random restart: move each point by a random direction of the
            # matching norm, scaled to half of min(best norm so far, eps).
            if self.norm == 'Linf':
                t = 2 * torch.rand(x1.shape).to(self.device) - 1
                x1 = im2 + (torch.min(res2,
                                      self.eps * torch.ones(res2.shape)
                                      .to(self.device)
                                      ).reshape([-1, *[1]*self.ndims])
                            ) * t / (t.reshape([t.shape[0], -1]).abs()
                                     .max(dim=1, keepdim=True)[0]
                                     .reshape([-1, *[1]*self.ndims])) * .5
            elif self.norm == 'L2':
                t = torch.randn(x1.shape).to(self.device)
                x1 = im2 + (torch.min(res2,
                                      self.eps * torch.ones(res2.shape)
                                      .to(self.device)
                                      ).reshape([-1, *[1]*self.ndims])
                            ) * t / ((t ** 2)
                                     .view(t.shape[0], -1)
                                     .sum(dim=-1)
                                     .sqrt()
                                     .view(t.shape[0], *[1]*self.ndims)) * .5
            elif self.norm == 'L1':
                t = torch.randn(x1.shape).to(self.device)
                x1 = im2 + (torch.min(res2,
                                      self.eps * torch.ones(res2.shape)
                                      .to(self.device)
                                      ).reshape([-1, *[1]*self.ndims])
                            ) * t / (t.abs().view(t.shape[0], -1)
                                     .sum(dim=-1)
                                     .view(t.shape[0], *[1]*self.ndims)) / 2

            # Keep the restart point inside the valid image box.
            x1 = x1.clamp(0.0, 1.0)

        counter_iter = 0
        while counter_iter < self.n_iter:
            with torch.no_grad():
                # df: logit differences per class; dg: their input gradients.
                if is_targeted:
                    df, dg = self.get_diff_logits_grads_batch_targeted(x1, la2, la_target2)
                else:
                    df, dg = self.get_diff_logits_grads_batch(x1, la2)
                # dist1: estimated distance to each decision hyperplane,
                # |df| divided by the dual norm of the gradient
                # (Linf -> L1, L2 -> L2, L1 -> Linf); 1e-12 avoids div-by-zero.
                if self.norm == 'Linf':
                    dist1 = df.abs() / (1e-12 +
                                        dg.abs()
                                        .reshape(dg.shape[0], dg.shape[1], -1)
                                        .sum(dim=-1))
                elif self.norm == 'L2':
                    dist1 = df.abs() / (1e-12 + (dg ** 2)
                                        .reshape(dg.shape[0], dg.shape[1], -1)
                                        .sum(dim=-1).sqrt())
                elif self.norm == 'L1':
                    dist1 = df.abs() / (1e-12 + dg.abs().reshape(
                        [df.shape[0], df.shape[1], -1]).max(dim=2)[0])
                else:
                    raise ValueError('norm not supported')
                # ind: index of the closest decision boundary per point; build
                # the hyperplane w^T z = b for that boundary.
                ind = dist1.min(dim=1)[1]
                dg2 = dg[u1, ind]
                b = (- df[u1, ind] + (dg2 * x1).reshape(x1.shape[0], -1)
                     .sum(dim=-1))
                w = dg2.reshape([bs, -1])

                # Project BOTH the current iterate x1 and the original x0 onto
                # the hyperplane (stacked into one batch of size 2*bs).
                if self.norm == 'Linf':
                    d3 = projection_linf(
                        torch.cat((x1.reshape([bs, -1]), x0), 0),
                        torch.cat((w, w), 0),
                        torch.cat((b, b), 0))
                elif self.norm == 'L2':
                    d3 = projection_l2(
                        torch.cat((x1.reshape([bs, -1]), x0), 0),
                        torch.cat((w, w), 0),
                        torch.cat((b, b), 0))
                elif self.norm == 'L1':
                    d3 = projection_l1(
                        torch.cat((x1.reshape([bs, -1]), x0), 0),
                        torch.cat((w, w), 0),
                        torch.cat((b, b), 0))
                # d1: projection step from the iterate; d2: from the original.
                d1 = torch.reshape(d3[:bs], x1.shape)
                d2 = torch.reshape(d3[-bs:], x1.shape)
                # a0: norm of each projection step, floored at 1e-8.
                if self.norm == 'Linf':
                    a0 = d3.abs().max(dim=1, keepdim=True)[0]\
                        .view(-1, *[1]*self.ndims)
                elif self.norm == 'L2':
                    a0 = (d3 ** 2).sum(dim=1, keepdim=True).sqrt()\
                        .view(-1, *[1]*self.ndims)
                elif self.norm == 'L1':
                    a0 = d3.abs().sum(dim=1, keepdim=True)\
                        .view(-1, *[1]*self.ndims)
                a0 = torch.max(a0, 1e-8 * torch.ones(
                    a0.shape).to(self.device))
                a1 = a0[:bs]
                a2 = a0[-bs:]
                # alpha biases the update toward the original point when its
                # projection step is relatively small, capped at alpha_max.
                alpha = torch.min(torch.max(a1 / (a1 + a2),
                                            torch.zeros(a1.shape)
                                            .to(self.device)),
                                  self.alpha_max * torch.ones(a1.shape)
                                  .to(self.device))
                # Convex combination of the two overshot (eta) steps, clamped
                # back into the image box.
                x1 = ((x1 + self.eta * d1) * (1 - alpha) +
                      (im2 + d2 * self.eta) * alpha).clamp(0.0, 1.0)

                is_adv = self._get_predicted_label(x1) != la2

                if is_adv.sum() > 0:
                    ind_adv = is_adv.nonzero().squeeze()
                    ind_adv = self.check_shape(ind_adv)
                    # t: perturbation norm of the newly adversarial iterates.
                    if self.norm == 'Linf':
                        t = (x1[ind_adv] - im2[ind_adv]).reshape(
                            [ind_adv.shape[0], -1]).abs().max(dim=1)[0]
                    elif self.norm == 'L2':
                        t = ((x1[ind_adv] - im2[ind_adv]) ** 2)\
                            .reshape(ind_adv.shape[0], -1).sum(dim=-1).sqrt()
                    elif self.norm == 'L1':
                        t = (x1[ind_adv] - im2[ind_adv])\
                            .abs().reshape(ind_adv.shape[0], -1).sum(dim=-1)
                    # Keep the new point only where it improves on the best
                    # norm so far; update the best norms accordingly.
                    adv[ind_adv] = x1[ind_adv] * (t < res2[ind_adv]).\
                        float().reshape([-1, *[1]*self.ndims]) + adv[ind_adv]\
                        * (t >= res2[ind_adv]).float().reshape(
                        [-1, *[1]*self.ndims])
                    res2[ind_adv] = t * (t < res2[ind_adv]).float()\
                        + res2[ind_adv] * (t >= res2[ind_adv]).float()
                    # Backward step: pull adversarial iterates back toward the
                    # original by factor beta to keep minimizing the norm.
                    x1[ind_adv] = im2[ind_adv] + (
                        x1[ind_adv] - im2[ind_adv]) * self.beta

                counter_iter += 1

        # Points with res2 still at 1e10 were never successfully attacked.
        ind_succ = res2 < 1e10
        if self.verbose:
            print('success rate: {:.0f}/{:.0f}'
                  .format(ind_succ.float().sum(), corr_classified) +
                  ' (on correctly classified points) in {:.1f} s'
                  .format(time.time() - startt))

        # Scatter the successes back into the full batch; everything else
        # stays at its clean value.
        ind_succ = self.check_shape(ind_succ.nonzero().squeeze())
        adv_c[pred[ind_succ]] = adv[ind_succ].clone()

        return adv_c

    def perturb(self, x, y):
        # Full attack: run n_restarts restarts (untargeted, or over
        # n_target_classes target classes when self.targeted), keeping the
        # best adversarial example within the eps budget for each point.
        if self.device is None:
            self.device = x.device
        adv = x.clone()
        with torch.no_grad():
            # acc: mask of points the model still classifies correctly.
            acc = self._predict_fn(x).max(1)[1] == y

            startt = time.time()

            torch.random.manual_seed(self.seed)
            torch.cuda.random.manual_seed(self.seed)

            if not self.targeted:
                for counter in range(self.n_restarts):
                    # Only re-attack points that are still robust.
                    ind_to_fool = acc.nonzero().squeeze()
                    if len(ind_to_fool.shape) == 0: ind_to_fool = ind_to_fool.unsqueeze(0)
                    if ind_to_fool.numel() != 0:
                        x_to_fool, y_to_fool = x[ind_to_fool].clone(), y[ind_to_fool].clone()
                        # First restart starts from the clean point; later
                        # restarts use a random start.
                        adv_curr = self.attack_single_run(x_to_fool, y_to_fool, use_rand_start=(counter > 0), is_targeted=False)

                        acc_curr = self._predict_fn(adv_curr).max(1)[1] == y_to_fool
                        if self.norm == 'Linf':
                            res = (x_to_fool - adv_curr).abs().reshape(x_to_fool.shape[0], -1).max(1)[0]
                        elif self.norm == 'L2':
                            res = ((x_to_fool - adv_curr) ** 2).reshape(x_to_fool.shape[0], -1).sum(dim=-1).sqrt()
                        elif self.norm == 'L1':
                            res = (x_to_fool - adv_curr).abs().reshape(x_to_fool.shape[0], -1).sum(-1)
                        # A point still counts as robust if the found
                        # perturbation exceeds the eps budget.
                        acc_curr = torch.max(acc_curr, res > self.eps)

                        ind_curr = (acc_curr == 0).nonzero().squeeze()
                        acc[ind_to_fool[ind_curr]] = 0
                        adv[ind_to_fool[ind_curr]] = adv_curr[ind_curr].clone()

                        if self.verbose:
                            print('restart {} - robust accuracy: {:.2%} at eps = {:.5f} - cum. time: {:.1f} s'.format(
                                counter, acc.float().mean(), self.eps, time.time() - startt))

            else:
                # Targeted mode: try the 2nd through (n_target_classes+1)-th
                # most likely classes as targets.
                for target_class in range(2, self.n_target_classes + 2):
                    self.target_class = target_class
                    for counter in range(self.n_restarts):
                        ind_to_fool = acc.nonzero().squeeze()
                        if len(ind_to_fool.shape) == 0: ind_to_fool = ind_to_fool.unsqueeze(0)
                        if ind_to_fool.numel() != 0:
                            x_to_fool, y_to_fool = x[ind_to_fool].clone(), y[ind_to_fool].clone()
                            adv_curr = self.attack_single_run(x_to_fool, y_to_fool, use_rand_start=(counter > 0), is_targeted=True)

                            acc_curr = self._predict_fn(adv_curr).max(1)[1] == y_to_fool
                            if self.norm == 'Linf':
                                res = (x_to_fool - adv_curr).abs().reshape(x_to_fool.shape[0], -1).max(1)[0]
                            elif self.norm == 'L2':
                                res = ((x_to_fool - adv_curr) ** 2).reshape(x_to_fool.shape[0], -1).sum(dim=-1).sqrt()
                            elif self.norm == 'L1':
                                res = (x_to_fool - adv_curr).abs().reshape(x_to_fool.shape[0], -1).sum(-1)
                            # Reject perturbations larger than the eps budget.
                            acc_curr = torch.max(acc_curr, res > self.eps)

                            ind_curr = (acc_curr == 0).nonzero().squeeze()
                            acc[ind_to_fool[ind_curr]] = 0
                            adv[ind_to_fool[ind_curr]] = adv_curr[ind_curr].clone()

                            if self.verbose:
                                print('restart {} - target_class {} - robust accuracy: {:.2%} at eps = {:.5f} - cum. time: {:.1f} s'.format(
                                    counter, self.target_class, acc.float().mean(), self.eps, time.time() - startt))

        return adv
|
|