|
import argparse |
|
import os |
|
|
|
import h5py |
|
|
|
from misc_functions import * |
|
from torchvision.datasets import ImageNet |
|
from tqdm import tqdm |
|
from ViT_explanation_generator import LRP, Baselines |
|
from ViT_LRP import vit_base_patch16_224 as vit_LRP |
|
from ViT_new import vit_base_patch16_224 |
|
from ViT_orig_LRP import vit_base_patch16_224 as vit_orig_LRP |
|
|
|
|
|
def normalize(tensor, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Standardize a batch of images channel-wise, in place.

    Computes ``(tensor - mean) / std`` per channel. The statistics are
    broadcast as ``[None, :, None, None]``, so ``tensor`` must be 4-D with the
    channel axis at dimension 1.

    Args:
        tensor: Floating-point tensor to normalize. It is modified in place.
        mean: Per-channel means (any sequence accepted by ``torch.as_tensor``).
        std: Per-channel standard deviations; no entry may be zero once
            converted to ``tensor``'s dtype.

    Returns:
        The same ``tensor`` object, after in-place normalization.

    Raises:
        ValueError: If any ``std`` entry is zero after conversion to the
            tensor's dtype (the division would yield inf/NaN).
    """
    dtype = tensor.dtype
    mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device)
    std = torch.as_tensor(std, dtype=dtype, device=tensor.device)
    if (std == 0).any():
        raise ValueError(
            "std evaluated to zero after conversion to {}, "
            "leading to division by zero.".format(dtype)
        )
    # In-place ops: the caller's tensor is updated and also returned.
    tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
    return tensor
|
|
|
|
|
def _build_explainers(args):
    """Return a table mapping each supported ``args.method`` to its explainer call.

    Every entry is a ``(data, index) -> relevance`` callable wrapping one of the
    module-level ``baselines`` / ``lrp`` / ``orig_lrp`` generators. Those globals
    (and ``args.is_ablation``) are only looked up when the callable is invoked.
    """
    return {
        "rollout": lambda data, index: baselines.generate_rollout(
            data, start_layer=1
        ),
        "lrp": lambda data, index: lrp.generate_LRP(
            data, start_layer=1, index=index
        ),
        "transformer_attribution": lambda data, index: lrp.generate_LRP(
            data, start_layer=1, method="grad", index=index
        ),
        "full_lrp": lambda data, index: orig_lrp.generate_LRP(
            data, method="full", index=index
        ),
        "lrp_last_layer": lambda data, index: orig_lrp.generate_LRP(
            data, method="last_layer", is_ablation=args.is_ablation, index=index
        ),
        "attn_last_layer": lambda data, index: lrp.generate_LRP(
            data, method="last_layer_attn", is_ablation=args.is_ablation
        ),
        "attn_gradcam": lambda data, index: baselines.generate_cam_attn(
            data, index=index
        ),
    }


def compute_saliency_and_save(args):
    """Compute a saliency map for every batch of ``sample_loader`` and store it.

    Results are written to ``<args.method_dir>/results.hdf5`` (opened in append
    mode) as three growing datasets: ``vis`` (saliency maps, ``N x 1 x 224 x
    224``), ``image`` (the un-normalized input images, ``N x 3 x 224 x 224``)
    and ``target`` (the labels yielded by the loader).

    Relies on module-level globals that the ``__main__`` section of this file
    sets up: ``sample_loader``, ``device``, ``baselines``, ``lrp`` and
    ``orig_lrp``.

    Args:
        args: Parsed command-line namespace. Reads ``method``, ``method_dir``,
            ``vis_class`` and, for the last-layer methods, ``is_ablation``.

    Raises:
        ValueError: If ``args.method`` is not a supported method. Raised before
            the HDF5 file is touched.
    """
    explainers = _build_explainers(args)
    if args.method not in explainers:
        raise ValueError(
            "Unsupported saliency method {!r}; expected one of: {}".format(
                args.method, ", ".join(sorted(explainers))
            )
        )
    explain = explainers[args.method]
    # "full_lrp" is reshaped straight to 224x224; every other method is reshaped
    # to a 14x14 grid and upsampled by 16 below.
    side = 224 if args.method == "full_lrp" else 14

    with h5py.File(os.path.join(args.method_dir, "results.hdf5"), "a") as f:
        data_cam = f.create_dataset(
            "vis",
            (1, 1, 224, 224),
            maxshape=(None, 1, 224, 224),
            dtype=np.float32,
            compression="gzip",
        )
        data_image = f.create_dataset(
            "image",
            (1, 3, 224, 224),
            maxshape=(None, 3, 224, 224),
            dtype=np.float32,
            compression="gzip",
        )
        data_target = f.create_dataset(
            "target", (1,), maxshape=(None,), dtype=np.int32, compression="gzip"
        )

        written = 0  # rows already filled; the initial placeholder row counts as free
        for data, target in tqdm(sample_loader):
            batch = data.shape[0]
            end = written + batch
            # Grow all three datasets together so they stay row-aligned. The
            # first batch reuses the placeholder row created above.
            for dataset in (data_cam, data_image, data_target):
                dataset.resize(end, axis=0)

            # Store the raw images before normalize() mutates `data` in place.
            data_image[written:end] = data.detach().cpu().numpy()
            data_target[written:end] = target.detach().cpu().numpy()

            target = target.to(device)
            data = normalize(data).to(device)
            data.requires_grad_()

            # NOTE(review): only vis_class == "target" selects an explicit class.
            # For "index", args.class_id is never forwarded to the explainer, so
            # it falls back to the explainer's default class choice — confirm
            # whether that is intended.
            index = target if args.vis_class == "target" else None

            saliency = explain(data, index).reshape(batch, 1, side, side)
            if side != 224:
                # Use the configured device instead of a hard-coded .cuda() so
                # CPU-only hosts work; identical when CUDA is available.
                saliency = torch.nn.functional.interpolate(
                    saliency, scale_factor=16, mode="bilinear"
                ).to(device)

            # Min-max scale to [0, 1]; the extrema are taken over the whole batch.
            low = saliency.min()
            span = saliency.max() - low
            if span == 0:
                # A constant map would divide by zero and store NaN.
                saliency = torch.zeros_like(saliency)
            else:
                saliency = (saliency - low) / span

            data_cam[written:end] = saliency.detach().cpu().numpy()
            written = end
|
|
|
|
|
def str2bool(value):
    """Convert a command-line token such as ``true``/``false`` to a bool.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including ``"False"``) becomes ``True``. This converter
    interprets the text instead.

    Args:
        value: The raw option value (a string, or an already-parsed bool).

    Returns:
        The parsed boolean. An empty string maps to ``False``.

    Raises:
        argparse.ArgumentTypeError: If the text is not a recognized boolean.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("true", "t", "yes", "y", "1"):
        return True
    if text in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value)
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate ViT saliency maps for ImageNet validation images"
    )
    parser.add_argument("--batch-size", type=int, default=1, help="")
    # The previous default ("grad_rollout") was not one of the choices and
    # matched no branch in compute_saliency_and_save, so omitting --method
    # always crashed. Require an explicit, valid choice instead.
    parser.add_argument(
        "--method",
        type=str,
        required=True,
        choices=[
            "rollout",
            "lrp",
            "transformer_attribution",
            "full_lrp",
            "lrp_last_layer",
            "attn_last_layer",
            "attn_gradcam",
        ],
        help="",
    )
    parser.add_argument("--lmd", type=float, default=10, help="")
    parser.add_argument(
        "--vis-class",
        type=str,
        default="top",
        choices=["top", "target", "index"],
        help="",
    )
    parser.add_argument("--class-id", type=int, default=0, help="")
    parser.add_argument("--cls-agn", action="store_true", default=False, help="")
    parser.add_argument("--no-ia", action="store_true", default=False, help="")
    parser.add_argument("--no-fx", action="store_true", default=False, help="")
    parser.add_argument("--no-fgx", action="store_true", default=False, help="")
    parser.add_argument("--no-m", action="store_true", default=False, help="")
    parser.add_argument("--no-reg", action="store_true", default=False, help="")
    # str2bool rather than bool: bool("False") is True.
    parser.add_argument("--is-ablation", type=str2bool, default=False, help="")
    parser.add_argument("--imagenet-validation-path", type=str, required=True, help="")
    args = parser.parse_args()

    # Output layout, rooted next to this script:
    #   visualizations/<method>/<vis_class>_<class_id>/            (vis_class == "index")
    #   visualizations/<method>/<vis_class>/<ablation|not_ablation>/   (otherwise)
    PATH = os.path.dirname(os.path.abspath(__file__)) + "/"
    if args.vis_class == "index":
        run_subdir = os.path.join(
            "visualizations",
            args.method,
            "{}_{}".format(args.vis_class, args.class_id),
        )
    else:
        ablation_fold = "ablation" if args.is_ablation else "not_ablation"
        run_subdir = os.path.join(
            "visualizations", args.method, args.vis_class, ablation_fold
        )
    args.method_dir = os.path.join(PATH, run_subdir)
    # makedirs also creates the intermediate visualizations/<method> levels.
    os.makedirs(args.method_dir, exist_ok=True)

    # Remove results left by a previous run from the directory they are actually
    # written to. compute_saliency_and_save opens the file in append mode and
    # creates its datasets, which fails if they already exist.
    try:
        os.remove(os.path.join(args.method_dir, "results.hdf5"))
    except FileNotFoundError:
        pass

    # The names below (device, baselines, lrp, orig_lrp, sample_loader) must stay
    # module-level: compute_saliency_and_save reads them as globals.
    cuda = torch.cuda.is_available()
    device = torch.device("cuda" if cuda else "cpu")

    # Model used by the Baselines explainers; put in eval mode like the two
    # LRP models below.
    model = vit_base_patch16_224(pretrained=True).to(device)
    model.eval()
    baselines = Baselines(model)

    # LRP-enabled model.
    model_LRP = vit_LRP(pretrained=True).to(device)
    model_LRP.eval()
    lrp = LRP(model_LRP)

    # Model for the "orig" LRP variants (full_lrp, lrp_last_layer).
    model_orig_LRP = vit_orig_LRP(pretrained=True).to(device)
    model_orig_LRP.eval()
    orig_lrp = LRP(model_orig_LRP)

    # Images are only resized and converted to tensors here; normalization is
    # applied inside compute_saliency_and_save after the raw image is stored.
    transform = transforms.Compose(
        [
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ]
    )

    imagenet_ds = ImageNet(
        args.imagenet_validation_path, split="val", download=False, transform=transform
    )
    sample_loader = torch.utils.data.DataLoader(
        imagenet_ds, batch_size=args.batch_size, shuffle=False, num_workers=4
    )

    compute_saliency_and_save(args)
|
|