Commit
·
1177ff0
1
Parent(s):
1f5f496
fix
Browse files
- __pycache__/layers.cpython-312.pyc +0 -0
- convolve.py +1 -3
- stu.py +1 -1
__pycache__/layers.cpython-312.pyc
CHANGED
Binary files a/__pycache__/layers.cpython-312.pyc and b/__pycache__/layers.cpython-312.pyc differ
|
|
convolve.py
CHANGED
@@ -41,8 +41,6 @@ def flash_convolve(
|
|
41 |
u: torch.Tensor, v: torch.Tensor, flash_fft: FlashFFTConv, use_approx: bool = True,
|
42 |
) -> tuple[torch.Tensor, torch.Tensor]:
|
43 |
dtype = u.dtype # Store the original dtype
|
44 |
-
u = u.to(torch.float32)
|
45 |
-
v = v.to(torch.float32)
|
46 |
|
47 |
bsz, seq_len, d_in = u.shape
|
48 |
_, K = v.shape
|
@@ -50,7 +48,7 @@ def flash_convolve(
|
|
50 |
padded_len = nearest_power_of_two(seq_len, round_up=True)
|
51 |
pad_len = padded_len - seq_len
|
52 |
|
53 |
-
sgn = torch.full((1, 1, padded_len), 1, device=u.device, dtype=
|
54 |
sgn[:, :, 1::2] = -1
|
55 |
|
56 |
if use_approx:
|
|
|
41 |
u: torch.Tensor, v: torch.Tensor, flash_fft: FlashFFTConv, use_approx: bool = True,
|
42 |
) -> tuple[torch.Tensor, torch.Tensor]:
|
43 |
dtype = u.dtype # Store the original dtype
|
|
|
|
|
44 |
|
45 |
bsz, seq_len, d_in = u.shape
|
46 |
_, K = v.shape
|
|
|
48 |
padded_len = nearest_power_of_two(seq_len, round_up=True)
|
49 |
pad_len = padded_len - seq_len
|
50 |
|
51 |
+
sgn = torch.full((1, 1, padded_len), 1, device=u.device, dtype=dtype)
|
52 |
sgn[:, :, 1::2] = -1
|
53 |
|
54 |
if use_approx:
|
stu.py
CHANGED
@@ -30,7 +30,7 @@ class STU(nn.Module):
|
|
30 |
self.use_hankel_L = config.use_hankel_L
|
31 |
self.use_approx = config.use_approx
|
32 |
self.flash_fft = (
|
33 |
-
FlashFFTConv(self.n, dtype=
|
34 |
if config.use_flash_fft and flash_fft_available
|
35 |
else None
|
36 |
)
|
|
|
30 |
self.use_hankel_L = config.use_hankel_L
|
31 |
self.use_approx = config.use_approx
|
32 |
self.flash_fft = (
|
33 |
+
FlashFFTConv(self.n, dtype=torch_dtype)
|
34 |
if config.use_flash_fft and flash_fft_available
|
35 |
else None
|
36 |
)
|