amildravid4292 commited on
Commit
2804b00
·
verified ·
1 Parent(s): 38f9244

Update lora_w2w.py

Browse files
Files changed (1) hide show
  1. lora_w2w.py +11 -7
lora_w2w.py CHANGED
@@ -71,13 +71,13 @@ class LoRAModule(nn.Module):
71
  self.in_dim = org_module.in_features
72
  self.out_dim = org_module.out_features
73
 
74
- self.proj = proj.bfloat16()
75
- self.mean1 = mean[0:self.in_dim].bfloat16()
76
- self.mean2 = mean[self.in_dim:].bfloat16()
77
- self.std1 = std[0:self.in_dim].bfloat16()
78
- self.std2 = std[self.in_dim:].bfloat16()
79
- self.v1 = v[0:self.in_dim].bfloat16()
80
- self.v2 = v[self.in_dim: ].bfloat16()
81
 
82
  if type(alpha) == torch.Tensor:
83
  alpha = alpha.detach().numpy()
@@ -95,6 +95,10 @@ class LoRAModule(nn.Module):
95
  del self.org_module
96
 
97
  def forward(self, x):
 
 
 
 
98
  return self.org_forward(x) +\
99
  (x@(([email protected])*self.std1+self.mean1).T)@((([email protected])*self.std2+self.mean2))*self.multiplier*self.scale
100
 
 
71
  self.in_dim = org_module.in_features
72
  self.out_dim = org_module.out_features
73
 
74
+ self.proj = proj.bfloat16().cuda()
75
+ self.mean1 = mean[0:self.in_dim].bfloat16().cuda()
76
+ self.mean2 = mean[self.in_dim:].bfloat16().cuda()
77
+ self.std1 = std[0:self.in_dim].bfloat16().cuda()
78
+ self.std2 = std[self.in_dim:].bfloat16().cuda()
79
+ self.v1 = v[0:self.in_dim].bfloat16().cuda()
80
+ self.v2 = v[self.in_dim: ].bfloat16().cuda()
81
 
82
  if type(alpha) == torch.Tensor:
83
  alpha = alpha.detach().numpy()
 
95
  del self.org_module
96
 
97
  def forward(self, x):
98
+ print(self.proj.device)
99
+ print(self.v1.device)
100
+ print(self.mean1.device)
101
+ print(self.std1.device)
102
  return self.org_forward(x) +\
103
  (x@(([email protected])*self.std1+self.mean1).T)@((([email protected])*self.std2+self.mean2))*self.multiplier*self.scale
104