File size: 5,206 Bytes
e1aaaac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
import torch
import warnings
import math
import sys
from autoattack.other_utils import L2_norm
# Call counters, keyed by function name, that ``tracefunc`` increments while
# ``check_dynamic`` runs a forward pass. The names presumably refer to
# torch.autograd.grad, Tensor.backward and torch.autograd._make_grads (matching
# is by bare function name only). 'enable_grad' is currently not tracked.
funcs = dict.fromkeys(('grad', 'backward', '_make_grads'), 0)
# Documentation file that several of the warning messages below point to.
checks_doc_path = 'flags_doc.md'
def check_randomized(model, x, y, bs=250, n=5, alpha=1e-4, logger=None):
    """Warn if ``model`` appears to be a randomized defense.

    ``model`` is evaluated ``n`` times on the same batch ``x`` with gradient
    tracking disabled. A warning is issued when either of these holds:

    * the number of points whose top-1 prediction (index of the max along
      dim 1) equals ``y`` is not the same in every run; or
    * two runs' outputs, each divided by its ``L2_norm``, are more than
      ``alpha`` apart (largest ``L2_norm`` of the difference over all pairs).

    Args:
        model: callable applied to ``x``.
        x: input batch.
        y: labels compared against the predicted classes.
        bs: not used by this function.
        n: number of forward passes.
        alpha: tolerance on the largest normalized output distance.
        logger: object with a ``log`` method; when ``None`` the warning is
            emitted through ``warnings.warn`` instead.
    """
    n_correct = []
    normed = []
    with torch.no_grad():
        for _ in range(n):
            out = model(x)
            n_correct.append((out.max(1)[1] == y).sum().item())
            # The 1e-10 keeps the division finite for an all-zero output.
            normed.append(out / (L2_norm(out, keepdim=True) + 1e-10))

    # Largest distance over every unordered pair of runs (0. if there are
    # fewer than two runs).
    pair_dists = [
        L2_norm(normed[i] - normed[j]).max().item()
        for i in range(n - 1)
        for j in range(i + 1, n)
    ]
    max_diff = max([0.] + pair_dists)
    # At least two runs disagree on how many points are classified correctly.
    acc_unstable = len(set(n_correct)) > 1

    if not (acc_unstable or max_diff > alpha):
        return
    msg = ('it seems to be a randomized defense! Please use version="rand".'
           f' See {checks_doc_path} for details.')
    if logger is not None:
        logger.log(f'Warning: {msg}')
    else:
        warnings.warn(Warning(msg))
def check_range_output(model, x, alpha=1e-5, logger=None):
    """Warn if ``model`` seems to return probabilities rather than logits.

    The output of ``model(x)``, computed with gradient tracking disabled, is
    treated as a probability distribution when both of these hold:

    * its maximum is below ``1 + alpha`` and its minimum is above ``-alpha``;
    * every slice along the last dimension sums to 1 within ``alpha``.

    Args:
        model: callable applied to ``x``.
        x: input batch.
        alpha: tolerance used by all three range/sum tests.
        logger: object with a ``log`` method; when ``None`` the warning is
            emitted through ``warnings.warn`` instead.

    Returns:
        ``output.shape[-1]``, the size of the last output dimension.
    """
    with torch.no_grad():
        out = model(x)
    values_in_unit_range = (out.max() < 1. + alpha) and (out.min() > -alpha)
    sums_to_one = ((out.sum(-1) - 1.).abs() < alpha).all()
    if values_in_unit_range and sums_to_one:
        msg = ('it seems that the output is a probability distribution,'
               ' please be sure that the logits are used!'
               f' See {checks_doc_path} for details.')
        if logger is not None:
            logger.log(f'Warning: {msg}')
        else:
            warnings.warn(Warning(msg))
    return out.shape[-1]
def check_zero_gradients(grad, logger=None):
    """Warn if some points of the batch have an all-zero gradient.

    Args:
        grad: gradient tensor whose first dimension indexes the points; all
            remaining dimensions are flattened per point.
        logger: object with a ``log`` method; when ``None`` the warning is
            emitted through ``warnings.warn`` instead.
    """
    if grad.shape[0] == 0:
        # Empty batch: there is no point to check, and reshape(0, -1) cannot
        # infer the size of the -1 dimension.
        return
    # ``reshape`` rather than ``view``: ``view`` raises when the strides are
    # incompatible (e.g. a transposed, non-contiguous gradient).
    per_point = grad.reshape(grad.shape[0], -1).abs().sum(-1)
    # Convert to a Python int so the message shows "3", not "tensor(3)".
    n_zero = int((per_point == 0).sum().item())
    if n_zero == 0:
        return
    msg = (f'there are {n_zero} points with zero gradient!'
           ' This might lead to unreliable evaluation with gradient-based attacks.'
           f' See {checks_doc_path} for details.')
    if logger is None:
        warnings.warn(Warning(msg))
    else:
        logger.log(f'Warning: {msg}')
def check_square_sr(acc_dict, alpha=.002, logger=None):
    """Warn if Square Attack lowered the accuracy below all other entries.

    Nothing happens unless ``acc_dict`` contains ``'square'`` and has more
    than two entries. The warning fires when the ``'square'`` accuracy is
    more than ``alpha`` below the lowest accuracy of the remaining entries.

    Args:
        acc_dict: mapping from names (presumably attack names) to robust
            accuracies; presumably fractions, given the ``%`` formatting.
        alpha: margin Square's accuracy must fall below the others by.
        logger: object with a ``log`` method; when ``None`` the warning is
            emitted through ``warnings.warn`` instead.
    """
    if 'square' not in acc_dict or len(acc_dict) <= 2:
        return
    lowest_other = min(v for k, v in acc_dict.items() if k != 'square')
    square_acc = acc_dict['square']
    if not square_acc < lowest_other - alpha:
        return
    msg = (
        'Square Attack has decreased the robust accuracy of'
        f' {lowest_other - square_acc:.2%}.'
        ' This might indicate that the robustness evaluation using'
        ' AutoAttack is unreliable. Consider running Square'
        ' Attack with more iterations and restarts or an adaptive attack.'
        f' See {checks_doc_path} for details.'
    )
    if logger is not None:
        logger.log(f'Warning: {msg}')
    else:
        warnings.warn(Warning(msg))
# Trace hook installed by ``check_dynamic``; the call-counting approach is from
# https://stackoverflow.com/questions/26119521/counting-function-calls-python
def tracefunc(frame, event, args):
    """Count calls to functions whose name is a key of ``funcs``.

    Intended for ``sys.settrace``. Only ``'call'`` events are considered, and
    functions are matched by ``frame.f_code.co_name`` alone, so any function
    with a tracked name is counted. Always returns ``None``.
    """
    if event != 'call':
        return None
    called = frame.f_code.co_name
    if called in funcs:
        funcs[called] += 1
    return None
def check_dynamic(model, x, is_tf_model=False, logger=None):
    """Warn if ``model`` looks like a dynamic defense.

    For non-TensorFlow models, one forward pass ``model(x)`` is run with
    ``tracefunc`` installed via ``sys.settrace``. If any function named in
    ``funcs`` is called during that pass, the model is flagged. When
    ``is_tf_model`` is true, the check is skipped and a warning saying it is
    unsupported is emitted instead.

    Args:
        model: callable evaluated once on ``x``.
        x: input passed to ``model``.
        is_tf_model: skip the check and report that it is unsupported.
        logger: object with a ``log`` method; when ``None`` the warning is
            emitted through ``warnings.warn`` instead.
    """
    if is_tf_model:
        msg = 'the check for dynamic defenses is not currently supported'
    else:
        msg = None
        # ``funcs`` is module-level state shared across calls. Reset it so a
        # detection in an earlier call cannot make this model look dynamic.
        for name in funcs:
            funcs[name] = 0
        previous_trace = sys.gettrace()
        sys.settrace(tracefunc)
        try:
            model(x)
        finally:
            # Always uninstall the hook, even if the forward pass raises, and
            # restore whatever tracer (debugger, coverage, ...) was active.
            sys.settrace(previous_trace)
        if any(count > 0 for count in funcs.values()):
            msg = ('it seems to be a dynamic defense! The evaluation'
                   ' with AutoAttack might be insufficient.'
                   f' See {checks_doc_path} for details.')
    if msg is not None:
        if logger is None:
            warnings.warn(Warning(msg))
        else:
            logger.log(f'Warning: {msg}')
def check_n_classes(n_cls, attacks_to_run, apgd_targets, fab_targets,
                    logger=None):
    """Warn if ``n_cls`` is too small for the requested attacks.

    Args:
        n_cls: number of classes.
        attacks_to_run: attack names, tested with ``in`` (e.g. ``'apgd-dlr'``,
            ``'apgd-t'``, ``'fab-t'``).
        apgd_targets: number of target classes requested for APGD-T.
        fab_targets: number of target classes requested for FAB-T.
        logger: object with a ``log`` method; when ``None`` the warning is
            emitted through ``warnings.warn`` instead.
    """
    run_apgd_t = 'apgd-t' in attacks_to_run
    msg = None
    if 'apgd-dlr' in attacks_to_run or run_apgd_t:
        if n_cls <= 2:
            msg = f'with only {n_cls} classes it is not possible to use the DLR loss!'
        elif n_cls == 3 and run_apgd_t:
            # With 3 classes only the *targeted* DLR loss is ruled out (per
            # this message). That presumably concerns APGD-T alone, so don't
            # warn when only the untargeted APGD-DLR is requested.
            msg = f'with only {n_cls} classes it is not possible to use the targeted DLR loss!'
        elif run_apgd_t and apgd_targets + 1 > n_cls:
            msg = (f'it seems that more target classes ({apgd_targets})'
                   f' than possible ({n_cls - 1}) are used in APGD-T!')
    if 'fab-t' in attacks_to_run and fab_targets + 1 > n_cls:
        # Report FAB-T's own target count, ``fab_targets``.
        if msg is None:
            msg = (f'it seems that more target classes ({fab_targets})'
                   f' than possible ({n_cls - 1}) are used in FAB-T!')
        else:
            msg += (f' Also, it seems that too many target classes ({fab_targets})'
                    f' are used in FAB-T ({n_cls - 1} possible)!')
    if msg is not None:
        if logger is None:
            warnings.warn(Warning(msg))
        else:
            logger.log(f'Warning: {msg}')
|