remove cuda().
Browse files
Transformer-Explainability/baselines/ViT/ViT_explanation_generator.py
CHANGED
@@ -50,7 +50,7 @@ class LRP:
|
|
50 |
one_hot[0, index] = 1
|
51 |
one_hot_vector = one_hot
|
52 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
53 |
-
one_hot = torch.sum(one_hot
|
54 |
|
55 |
self.model.zero_grad()
|
56 |
one_hot.backward(retain_graph=True)
|
@@ -70,14 +70,14 @@ class Baselines:
|
|
70 |
self.model.eval()
|
71 |
|
72 |
def generate_cam_attn(self, input, index=None):
|
73 |
-
output = self.model(input
|
74 |
if index == None:
|
75 |
index = np.argmax(output.cpu().data.numpy())
|
76 |
|
77 |
one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
|
78 |
one_hot[0][index] = 1
|
79 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
80 |
-
one_hot = torch.sum(one_hot
|
81 |
|
82 |
self.model.zero_grad()
|
83 |
one_hot.backward(retain_graph=True)
|
|
|
50 |
one_hot[0, index] = 1
|
51 |
one_hot_vector = one_hot
|
52 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
53 |
+
one_hot = torch.sum(one_hot * output)
|
54 |
|
55 |
self.model.zero_grad()
|
56 |
one_hot.backward(retain_graph=True)
|
|
|
70 |
self.model.eval()
|
71 |
|
72 |
def generate_cam_attn(self, input, index=None):
|
73 |
+
output = self.model(input, register_hook=True)
|
74 |
if index == None:
|
75 |
index = np.argmax(output.cpu().data.numpy())
|
76 |
|
77 |
one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
|
78 |
one_hot[0][index] = 1
|
79 |
one_hot = torch.from_numpy(one_hot).requires_grad_(True)
|
80 |
+
one_hot = torch.sum(one_hot * output)
|
81 |
|
82 |
self.model.zero_grad()
|
83 |
one_hot.backward(retain_graph=True)
|