# Copyright (c) 2019-present, Francesco Croce # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. # from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals import time import torch from autoattack.fab_projections import projection_linf, projection_l2,\ projection_l1 DEFAULT_EPS_DICT_BY_NORM = {'Linf': .3, 'L2': 1., 'L1': 5.0} class FABAttack(): """ Fast Adaptive Boundary Attack (Linf, L2, L1) https://arxiv.org/abs/1907.02044 :param norm: Lp-norm to minimize ('Linf', 'L2', 'L1' supported) :param n_restarts: number of random restarts :param n_iter: number of iterations :param eps: epsilon for the random restarts :param alpha_max: alpha_max :param eta: overshooting :param beta: backward step """ def __init__( self, norm='Linf', n_restarts=1, n_iter=100, eps=None, alpha_max=0.1, eta=1.05, beta=0.9, loss_fn=None, verbose=False, seed=0, targeted=False, device=None, n_target_classes=9): """ FAB-attack implementation in pytorch """ self.norm = norm self.n_restarts = n_restarts self.n_iter = n_iter self.eps = eps if eps is not None else DEFAULT_EPS_DICT_BY_NORM[norm] self.alpha_max = alpha_max self.eta = eta self.beta = beta self.targeted = targeted self.verbose = verbose self.seed = seed self.target_class = None self.device = device self.n_target_classes = n_target_classes def check_shape(self, x): return x if len(x.shape) > 0 else x.unsqueeze(0) def _predict_fn(self, x): raise NotImplementedError("Virtual function.") def _get_predicted_label(self, x): raise NotImplementedError("Virtual function.") def get_diff_logits_grads_batch(self, imgs, la): raise NotImplementedError("Virtual function.") def get_diff_logits_grads_batch_targeted(self, imgs, la, la_target): raise NotImplementedError("Virtual function.") def attack_single_run(self, x, y=None, use_rand_start=False, is_targeted=False): """ :param x: clean images :param y: clean labels, if None we use the predicted labels :param is_targeted True if we ise targeted version. Targeted class is assigned by `self.target_class` """ if self.device is None: self.device = x.device self.orig_dim = list(x.shape[1:]) self.ndims = len(self.orig_dim) x = x.detach().clone().float().to(self.device) #assert next(self.predict.parameters()).device == x.device y_pred = self._get_predicted_label(x) if y is None: y = y_pred.detach().clone().long().to(self.device) else: y = y.detach().clone().long().to(self.device) pred = y_pred == y corr_classified = pred.float().sum() if self.verbose: print('Clean accuracy: {:.2%}'.format(pred.float().mean())) if pred.sum() == 0: return x pred = self.check_shape(pred.nonzero().squeeze()) if is_targeted: output = self._predict_fn(x) la_target = output.sort(dim=-1)[1][:, -self.target_class] la_target2 = la_target[pred].detach().clone() startt = time.time() # runs the attack only on correctly classified points im2 = x[pred].detach().clone() la2 = y[pred].detach().clone() if len(im2.shape) == self.ndims: im2 = im2.unsqueeze(0) bs = im2.shape[0] u1 = torch.arange(bs) adv = im2.clone() adv_c = x.clone() res2 = 1e10 * torch.ones([bs]).to(self.device) x1 = im2.clone() x0 = im2.clone().reshape([bs, -1]) if use_rand_start: if self.norm == 'Linf': t = 2 * torch.rand(x1.shape).to(self.device) - 1 x1 = im2 + (torch.min(res2, self.eps * torch.ones(res2.shape) .to(self.device) ).reshape([-1, *[1]*self.ndims]) ) * t / (t.reshape([t.shape[0], -1]).abs() .max(dim=1, keepdim=True)[0] .reshape([-1, *[1]*self.ndims])) * .5 elif self.norm == 'L2': t = torch.randn(x1.shape).to(self.device) x1 = im2 + (torch.min(res2, self.eps * torch.ones(res2.shape) .to(self.device) ).reshape([-1, *[1]*self.ndims]) ) * t / ((t ** 2) .view(t.shape[0], -1) .sum(dim=-1) .sqrt() .view(t.shape[0], *[1]*self.ndims)) * .5 elif self.norm == 'L1': t = torch.randn(x1.shape).to(self.device) x1 = im2 + (torch.min(res2, self.eps * torch.ones(res2.shape) .to(self.device) ).reshape([-1, *[1]*self.ndims]) ) * t / (t.abs().view(t.shape[0], -1) .sum(dim=-1) .view(t.shape[0], *[1]*self.ndims)) / 2 x1 = x1.clamp(0.0, 1.0) counter_iter = 0 while counter_iter < self.n_iter: with torch.no_grad(): if is_targeted: df, dg = self.get_diff_logits_grads_batch_targeted(x1, la2, la_target2) else: df, dg = self.get_diff_logits_grads_batch(x1, la2) if self.norm == 'Linf': dist1 = df.abs() / (1e-12 + dg.abs() .reshape(dg.shape[0], dg.shape[1], -1) .sum(dim=-1)) elif self.norm == 'L2': dist1 = df.abs() / (1e-12 + (dg ** 2) .reshape(dg.shape[0], dg.shape[1], -1) .sum(dim=-1).sqrt()) elif self.norm == 'L1': dist1 = df.abs() / (1e-12 + dg.abs().reshape( [df.shape[0], df.shape[1], -1]).max(dim=2)[0]) else: raise ValueError('norm not supported') ind = dist1.min(dim=1)[1] dg2 = dg[u1, ind] b = (- df[u1, ind] + (dg2 * x1).reshape(x1.shape[0], -1) .sum(dim=-1)) w = dg2.reshape([bs, -1]) if self.norm == 'Linf': d3 = projection_linf( torch.cat((x1.reshape([bs, -1]), x0), 0), torch.cat((w, w), 0), torch.cat((b, b), 0)) elif self.norm == 'L2': d3 = projection_l2( torch.cat((x1.reshape([bs, -1]), x0), 0), torch.cat((w, w), 0), torch.cat((b, b), 0)) elif self.norm == 'L1': d3 = projection_l1( torch.cat((x1.reshape([bs, -1]), x0), 0), torch.cat((w, w), 0), torch.cat((b, b), 0)) d1 = torch.reshape(d3[:bs], x1.shape) d2 = torch.reshape(d3[-bs:], x1.shape) if self.norm == 'Linf': a0 = d3.abs().max(dim=1, keepdim=True)[0]\ .view(-1, *[1]*self.ndims) elif self.norm == 'L2': a0 = (d3 ** 2).sum(dim=1, keepdim=True).sqrt()\ .view(-1, *[1]*self.ndims) elif self.norm == 'L1': a0 = d3.abs().sum(dim=1, keepdim=True)\ .view(-1, *[1]*self.ndims) a0 = torch.max(a0, 1e-8 * torch.ones( a0.shape).to(self.device)) a1 = a0[:bs] a2 = a0[-bs:] alpha = torch.min(torch.max(a1 / (a1 + a2), torch.zeros(a1.shape) .to(self.device)), self.alpha_max * torch.ones(a1.shape) .to(self.device)) x1 = ((x1 + self.eta * d1) * (1 - alpha) + (im2 + d2 * self.eta) * alpha).clamp(0.0, 1.0) is_adv = self._get_predicted_label(x1) != la2 if is_adv.sum() > 0: ind_adv = is_adv.nonzero().squeeze() ind_adv = self.check_shape(ind_adv) if self.norm == 'Linf': t = (x1[ind_adv] - im2[ind_adv]).reshape( [ind_adv.shape[0], -1]).abs().max(dim=1)[0] elif self.norm == 'L2': t = ((x1[ind_adv] - im2[ind_adv]) ** 2)\ .reshape(ind_adv.shape[0], -1).sum(dim=-1).sqrt() elif self.norm == 'L1': t = (x1[ind_adv] - im2[ind_adv])\ .abs().reshape(ind_adv.shape[0], -1).sum(dim=-1) adv[ind_adv] = x1[ind_adv] * (t < res2[ind_adv]).\ float().reshape([-1, *[1]*self.ndims]) + adv[ind_adv]\ * (t >= res2[ind_adv]).float().reshape( [-1, *[1]*self.ndims]) res2[ind_adv] = t * (t < res2[ind_adv]).float()\ + res2[ind_adv] * (t >= res2[ind_adv]).float() x1[ind_adv] = im2[ind_adv] + ( x1[ind_adv] - im2[ind_adv]) * self.beta counter_iter += 1 ind_succ = res2 < 1e10 if self.verbose: print('success rate: {:.0f}/{:.0f}' .format(ind_succ.float().sum(), corr_classified) + ' (on correctly classified points) in {:.1f} s' .format(time.time() - startt)) ind_succ = self.check_shape(ind_succ.nonzero().squeeze()) adv_c[pred[ind_succ]] = adv[ind_succ].clone() return adv_c def perturb(self, x, y): if self.device is None: self.device = x.device adv = x.clone() with torch.no_grad(): acc = self._predict_fn(x).max(1)[1] == y startt = time.time() torch.random.manual_seed(self.seed) torch.cuda.random.manual_seed(self.seed) if not self.targeted: for counter in range(self.n_restarts): ind_to_fool = acc.nonzero().squeeze() if len(ind_to_fool.shape) == 0: ind_to_fool = ind_to_fool.unsqueeze(0) if ind_to_fool.numel() != 0: x_to_fool, y_to_fool = x[ind_to_fool].clone(), y[ind_to_fool].clone() adv_curr = self.attack_single_run(x_to_fool, y_to_fool, use_rand_start=(counter > 0), is_targeted=False) acc_curr = self._predict_fn(adv_curr).max(1)[1] == y_to_fool if self.norm == 'Linf': res = (x_to_fool - adv_curr).abs().reshape(x_to_fool.shape[0], -1).max(1)[0] elif self.norm == 'L2': res = ((x_to_fool - adv_curr) ** 2).reshape(x_to_fool.shape[0], -1).sum(dim=-1).sqrt() elif self.norm == 'L1': res = (x_to_fool - adv_curr).abs().reshape(x_to_fool.shape[0], -1).sum(-1) acc_curr = torch.max(acc_curr, res > self.eps) ind_curr = (acc_curr == 0).nonzero().squeeze() acc[ind_to_fool[ind_curr]] = 0 adv[ind_to_fool[ind_curr]] = adv_curr[ind_curr].clone() if self.verbose: print('restart {} - robust accuracy: {:.2%} at eps = {:.5f} - cum. time: {:.1f} s'.format( counter, acc.float().mean(), self.eps, time.time() - startt)) else: for target_class in range(2, self.n_target_classes + 2): self.target_class = target_class for counter in range(self.n_restarts): ind_to_fool = acc.nonzero().squeeze() if len(ind_to_fool.shape) == 0: ind_to_fool = ind_to_fool.unsqueeze(0) if ind_to_fool.numel() != 0: x_to_fool, y_to_fool = x[ind_to_fool].clone(), y[ind_to_fool].clone() adv_curr = self.attack_single_run(x_to_fool, y_to_fool, use_rand_start=(counter > 0), is_targeted=True) acc_curr = self._predict_fn(adv_curr).max(1)[1] == y_to_fool if self.norm == 'Linf': res = (x_to_fool - adv_curr).abs().reshape(x_to_fool.shape[0], -1).max(1)[0] elif self.norm == 'L2': res = ((x_to_fool - adv_curr) ** 2).reshape(x_to_fool.shape[0], -1).sum(dim=-1).sqrt() elif self.norm == 'L1': res = (x_to_fool - adv_curr).abs().reshape(x_to_fool.shape[0], -1).sum(-1) acc_curr = torch.max(acc_curr, res > self.eps) ind_curr = (acc_curr == 0).nonzero().squeeze() acc[ind_to_fool[ind_curr]] = 0 adv[ind_to_fool[ind_curr]] = adv_curr[ind_curr].clone() if self.verbose: print('restart {} - target_class {} - robust accuracy: {:.2%} at eps = {:.5f} - cum. time: {:.1f} s'.format( counter, self.target_class, acc.float().mean(), self.eps, time.time() - startt)) return adv