yagizdevre committed on
Commit
1177ff0
·
1 Parent(s): 1f5f496
Files changed (3) hide show
  1. __pycache__/layers.cpython-312.pyc +0 -0
  2. convolve.py +1 -3
  3. stu.py +1 -1
__pycache__/layers.cpython-312.pyc CHANGED
Binary files a/__pycache__/layers.cpython-312.pyc and b/__pycache__/layers.cpython-312.pyc differ
 
convolve.py CHANGED
@@ -41,8 +41,6 @@ def flash_convolve(
41
  u: torch.Tensor, v: torch.Tensor, flash_fft: FlashFFTConv, use_approx: bool = True,
42
  ) -> tuple[torch.Tensor, torch.Tensor]:
43
  dtype = u.dtype # Store the original dtype
44
- u = u.to(torch.float32)
45
- v = v.to(torch.float32)
46
 
47
  bsz, seq_len, d_in = u.shape
48
  _, K = v.shape
@@ -50,7 +48,7 @@ def flash_convolve(
50
  padded_len = nearest_power_of_two(seq_len, round_up=True)
51
  pad_len = padded_len - seq_len
52
 
53
- sgn = torch.full((1, 1, padded_len), 1, device=u.device, dtype=torch.float32)
54
  sgn[:, :, 1::2] = -1
55
 
56
  if use_approx:
 
41
  u: torch.Tensor, v: torch.Tensor, flash_fft: FlashFFTConv, use_approx: bool = True,
42
  ) -> tuple[torch.Tensor, torch.Tensor]:
43
  dtype = u.dtype # Store the original dtype
 
 
44
 
45
  bsz, seq_len, d_in = u.shape
46
  _, K = v.shape
 
48
  padded_len = nearest_power_of_two(seq_len, round_up=True)
49
  pad_len = padded_len - seq_len
50
 
51
+ sgn = torch.full((1, 1, padded_len), 1, device=u.device, dtype=dtype)
52
  sgn[:, :, 1::2] = -1
53
 
54
  if use_approx:
stu.py CHANGED
@@ -30,7 +30,7 @@ class STU(nn.Module):
30
  self.use_hankel_L = config.use_hankel_L
31
  self.use_approx = config.use_approx
32
  self.flash_fft = (
33
- FlashFFTConv(self.n, dtype=torch.bfloat16)
34
  if config.use_flash_fft and flash_fft_available
35
  else None
36
  )
 
30
  self.use_hankel_L = config.use_hankel_L
31
  self.use_approx = config.use_approx
32
  self.flash_fft = (
33
+ FlashFFTConv(self.n, dtype=torch_dtype)
34
  if config.use_flash_fft and flash_fft_available
35
  else None
36
  )