sayakpaul HF staff commited on
Commit
1d5ecf3
1 Parent(s): 8e6f824

remove cuda().

Browse files
Transformer-Explainability/baselines/ViT/ViT_explanation_generator.py CHANGED
@@ -50,7 +50,7 @@ class LRP:
50
  one_hot[0, index] = 1
51
  one_hot_vector = one_hot
52
  one_hot = torch.from_numpy(one_hot).requires_grad_(True)
53
- one_hot = torch.sum(one_hot.cuda() * output)
54
 
55
  self.model.zero_grad()
56
  one_hot.backward(retain_graph=True)
@@ -70,14 +70,14 @@ class Baselines:
70
  self.model.eval()
71
 
72
  def generate_cam_attn(self, input, index=None):
73
- output = self.model(input.cuda(), register_hook=True)
74
  if index == None:
75
  index = np.argmax(output.cpu().data.numpy())
76
 
77
  one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
78
  one_hot[0][index] = 1
79
  one_hot = torch.from_numpy(one_hot).requires_grad_(True)
80
- one_hot = torch.sum(one_hot.cuda() * output)
81
 
82
  self.model.zero_grad()
83
  one_hot.backward(retain_graph=True)
 
50
  one_hot[0, index] = 1
51
  one_hot_vector = one_hot
52
  one_hot = torch.from_numpy(one_hot).requires_grad_(True)
53
+ one_hot = torch.sum(one_hot * output)
54
 
55
  self.model.zero_grad()
56
  one_hot.backward(retain_graph=True)
 
70
  self.model.eval()
71
 
72
  def generate_cam_attn(self, input, index=None):
73
+ output = self.model(input, register_hook=True)
74
  if index == None:
75
  index = np.argmax(output.cpu().data.numpy())
76
 
77
  one_hot = np.zeros((1, output.size()[-1]), dtype=np.float32)
78
  one_hot[0][index] = 1
79
  one_hot = torch.from_numpy(one_hot).requires_grad_(True)
80
+ one_hot = torch.sum(one_hot * output)
81
 
82
  self.model.zero_grad()
83
  one_hot.backward(retain_graph=True)