# comparative-explainability/Transformer-Explainability/baselines/ViT/pertubation_eval_from_hdf5.py
import argparse
# from models.vgg import vgg19
import glob
import os
import numpy as np
import torch
from dataset.expl_hdf5 import ImagenetResults
from tqdm import tqdm
# Import saliency methods and models
from ViT_explanation_generator import Baselines
from ViT_new import vit_base_patch16_224
def normalize(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Channel-wise normalize *tensor* IN PLACE: (tensor - mean) / std.

    Args:
        tensor: NCHW image batch. It is mutated in place via sub_/div_,
            which is why callers pass ``data.clone()`` when they need to
            keep the original pixel values.
        mean, std: per-channel statistics. Tuples rather than lists so the
            defaults are immutable (avoids the mutable-default pitfall).

    Returns:
        The same tensor object, for call-chaining.
    """
    dtype = tensor.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    # Broadcast (C,) stats over the N, H, W axes.
    tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
    return tensor
def eval(args):
    """Run the perturbation evaluation over ``sample_loader``.

    For each batch, zeroes out the top-k most salient pixels (per the
    ``vis`` map) at 9 increasing perturbation steps, re-runs the model,
    and records per-sample accuracy, log-odds dissimilarity, and logit /
    probability differences. Results are saved as .npy files under
    ``args.experiment_dir`` and summary statistics are printed.

    NOTE(review): shadows the builtin ``eval``; also reads the module-level
    globals ``imagenet_ds``, ``sample_loader``, ``model`` and ``device``
    that are set up in the ``__main__`` block.
    """
    num_samples = 0
    # Per-sample hit (1/0) of the unperturbed model prediction vs. target.
    num_correct_model = np.zeros((len(imagenet_ds,)))
    # Per-sample log(p_target / p_second_best) for the unperturbed model.
    dissimilarity_model = np.zeros((len(imagenet_ds,)))
    model_index = 0

    if args.scale == "per":
        # Steps are fractions of the full 224x224 pixel grid.
        base_size = 224 * 224
        perturbation_steps = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    elif args.scale == "100":
        base_size = 100
        perturbation_steps = [5, 10, 15, 20, 25, 30, 35, 40, 45]
    else:
        raise Exception("scale not valid")

    # One row per perturbation step, one column per (kept) sample.
    num_correct_pertub = np.zeros((9, len(imagenet_ds)))
    dissimilarity_pertub = np.zeros((9, len(imagenet_ds)))
    logit_diff_pertub = np.zeros((9, len(imagenet_ds)))
    prob_diff_pertub = np.zeros((9, len(imagenet_ds)))
    perturb_index = 0

    for batch_idx, (data, vis, target) in enumerate(tqdm(sample_loader)):
        # Update the number of samples
        num_samples += len(data)
        data = data.to(device)
        vis = vis.to(device)
        target = target.to(device)
        # normalize() mutates its argument in place, hence the clone.
        norm_data = normalize(data.clone())

        # Compute model accuracy
        pred = model(norm_data)
        pred_probabilities = torch.softmax(pred, dim=1)
        pred_org_logit = pred.data.max(1, keepdim=True)[0].squeeze(1)
        pred_org_prob = pred_probabilities.data.max(1, keepdim=True)[0].squeeze(1)
        pred_class = pred.data.max(1, keepdim=True)[1].squeeze(1)
        tgt_pred = (target == pred_class).type(target.type()).data.cpu().numpy()
        num_correct_model[model_index : model_index + len(tgt_pred)] = tgt_pred

        # Log-odds margin between the target class and the runner-up class.
        probs = torch.softmax(pred, dim=1)
        target_probs = torch.gather(probs, 1, target[:, None])[:, 0]
        second_probs = probs.data.topk(2, dim=1)[0][:, 1]
        temp = torch.log(target_probs / second_probs).data.cpu().numpy()
        dissimilarity_model[model_index : model_index + len(temp)] = temp

        if args.wrong:
            # Keep only the samples the unperturbed model got wrong.
            wid = np.argwhere(tgt_pred == 0).flatten()
            if len(wid) == 0:
                # NOTE(review): skipping (and, below, advancing model_index by
                # the FILTERED count) leaves model_index out of sync with the
                # full-batch writes into num_correct_model /
                # dissimilarity_model above, so later batches overwrite those
                # slots — confirm whether this is intended.
                continue
            wid = torch.from_numpy(wid).to(vis.device)
            vis = vis.index_select(0, wid)
            data = data.index_select(0, wid)
            target = target.index_select(0, wid)

        # Save original shape
        org_shape = data.shape

        if args.neg:
            # Negative perturbation: flip saliency so the LEAST relevant
            # pixels are removed first.
            vis = -vis

        vis = vis.reshape(org_shape[0], -1)

        for i in range(len(perturbation_steps)):
            _data = data.clone()
            # Zero out the top-k pixels (per flattened saliency) in every
            # channel of the image.
            _, idx = torch.topk(vis, int(base_size * perturbation_steps[i]), dim=-1)
            idx = idx.unsqueeze(1).repeat(1, org_shape[1], 1)
            _data = _data.reshape(org_shape[0], org_shape[1], -1)
            _data = _data.scatter_(-1, idx, 0)
            _data = _data.reshape(*org_shape)

            _norm_data = normalize(_data)
            out = model(_norm_data)

            # Probability drop of the (new) top class vs. the original top prob.
            pred_probabilities = torch.softmax(out, dim=1)
            pred_prob = pred_probabilities.data.max(1, keepdim=True)[0].squeeze(1)
            diff = (pred_prob - pred_org_prob).data.cpu().numpy()
            prob_diff_pertub[i, perturb_index : perturb_index + len(diff)] = diff

            pred_logit = out.data.max(1, keepdim=True)[0].squeeze(1)
            diff = (pred_logit - pred_org_logit).data.cpu().numpy()
            logit_diff_pertub[i, perturb_index : perturb_index + len(diff)] = diff

            target_class = out.data.max(1, keepdim=True)[1].squeeze(1)
            temp = (target == target_class).type(target.type()).data.cpu().numpy()
            num_correct_pertub[i, perturb_index : perturb_index + len(temp)] = temp

            probs_pertub = torch.softmax(out, dim=1)
            target_probs = torch.gather(probs_pertub, 1, target[:, None])[:, 0]
            second_probs = probs_pertub.data.topk(2, dim=1)[0][:, 1]
            temp = torch.log(target_probs / second_probs).data.cpu().numpy()
            dissimilarity_pertub[i, perturb_index : perturb_index + len(temp)] = temp

        model_index += len(target)
        perturb_index += len(target)

    # Columns past perturb_index were never written (e.g. batches skipped by
    # --wrong), so the perturbation arrays are truncated before saving.
    np.save(os.path.join(args.experiment_dir, "model_hits.npy"), num_correct_model)
    np.save(
        os.path.join(args.experiment_dir, "model_dissimilarities.npy"),
        dissimilarity_model,
    )
    np.save(
        os.path.join(args.experiment_dir, "perturbations_hits.npy"),
        num_correct_pertub[:, :perturb_index],
    )
    np.save(
        os.path.join(args.experiment_dir, "perturbations_dissimilarities.npy"),
        dissimilarity_pertub[:, :perturb_index],
    )
    np.save(
        os.path.join(args.experiment_dir, "perturbations_logit_diff.npy"),
        logit_diff_pertub[:, :perturb_index],
    )
    np.save(
        os.path.join(args.experiment_dir, "perturbations_prob_diff.npy"),
        prob_diff_pertub[:, :perturb_index],
    )

    print(np.mean(num_correct_model), np.std(num_correct_model))
    print(np.mean(dissimilarity_model), np.std(dissimilarity_model))
    print(perturbation_steps)
    print(np.mean(num_correct_pertub, axis=1), np.std(num_correct_pertub, axis=1))
    print(np.mean(dissimilarity_pertub, axis=1), np.std(dissimilarity_pertub, axis=1))
if __name__ == "__main__":

    def _str2bool(value):
        """Parse a CLI boolean flag value.

        argparse's ``type=bool`` is broken for this purpose: ``bool("False")``
        is True, so e.g. ``--neg False`` was silently ignored. Accepts the
        usual yes/no spellings; raises ArgumentTypeError otherwise.
        """
        if isinstance(value, bool):
            return value
        if value.lower() in ("yes", "true", "t", "y", "1"):
            return True
        if value.lower() in ("no", "false", "f", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("boolean value expected, got %r" % (value,))

    parser = argparse.ArgumentParser(description="Train a segmentation")
    parser.add_argument("--batch-size", type=int, default=16, help="")
    # _str2bool, not type=bool: see _str2bool docstring.
    parser.add_argument("--neg", type=_str2bool, default=True, help="")
    parser.add_argument("--value", action="store_true", default=False, help="")
    parser.add_argument(
        "--scale", type=str, default="per", choices=["per", "100"], help=""
    )
    parser.add_argument(
        "--method",
        type=str,
        default="grad_rollout",
        choices=[
            "rollout",
            "lrp",
            "transformer_attribution",
            "full_lrp",
            "v_gradcam",
            "lrp_last_layer",
            "lrp_second_layer",
            "gradcam",
            "attn_last_layer",
            "attn_gradcam",
            "input_grads",
        ],
        help="",
    )
    parser.add_argument(
        "--vis-class",
        type=str,
        default="top",
        choices=["top", "target", "index"],
        help="",
    )
    parser.add_argument("--wrong", action="store_true", default=False, help="")
    parser.add_argument("--class-id", type=int, default=0, help="")
    parser.add_argument("--is-ablation", type=_str2bool, default=False, help="")
    args = parser.parse_args()

    torch.multiprocessing.set_start_method("spawn")

    # PATH variables
    PATH = os.path.dirname(os.path.abspath(__file__)) + "/"
    os.makedirs(os.path.join(PATH, "experiments"), exist_ok=True)
    os.makedirs(os.path.join(PATH, "experiments/perturbations"), exist_ok=True)

    exp_name = args.method
    exp_name += "_neg" if args.neg else "_pos"
    print(exp_name)

    # Results directory layout mirrors the visualizations layout chosen below.
    if args.vis_class == "index":
        args.runs_dir = os.path.join(
            PATH,
            "experiments/perturbations/{}/{}_{}".format(
                exp_name, args.vis_class, args.class_id
            ),
        )
    else:
        ablation_fold = "ablation" if args.is_ablation else "not_ablation"
        args.runs_dir = os.path.join(
            PATH,
            "experiments/perturbations/{}/{}/{}".format(
                exp_name, args.vis_class, ablation_fold
            ),
        )

    if args.wrong:
        args.runs_dir += "_wrong"

    # Auto-increment experiment_<id>. Use the numeric max rather than
    # sorted(...)[-1]: lexicographic order ranks "experiment_9" above
    # "experiment_10", which used to reuse existing directories.
    existing_ids = [
        int(path.rsplit("_", 1)[-1])
        for path in glob.glob(os.path.join(args.runs_dir, "experiment_*"))
    ]
    experiment_id = max(existing_ids) + 1 if existing_ids else 0
    args.experiment_dir = os.path.join(
        args.runs_dir, "experiment_{}".format(str(experiment_id))
    )
    os.makedirs(args.experiment_dir, exist_ok=True)

    cuda = torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")

    if args.vis_class == "index":
        vis_method_dir = os.path.join(
            PATH,
            "visualizations/{}/{}_{}".format(
                args.method, args.vis_class, args.class_id
            ),
        )
    else:
        ablation_fold = "ablation" if args.is_ablation else "not_ablation"
        vis_method_dir = os.path.join(
            PATH,
            "visualizations/{}/{}/{}".format(
                args.method, args.vis_class, ablation_fold
            ),
        )

    imagenet_ds = ImagenetResults(vis_method_dir)

    # Model: .to(device) instead of .cuda() so CPU-only machines still run.
    model = vit_base_patch16_224(pretrained=True).to(device)
    model.eval()

    sample_loader = torch.utils.data.DataLoader(
        imagenet_ds, batch_size=args.batch_size, num_workers=2, shuffle=False
    )

    eval(args)