import math

import numpy as np
import torch

from .utils import get_hankel, logger


def get_spectral_filters(
    seq_len: int,
    K: int,
    use_hankel_L: bool = False,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Compute the top-K spectral filters of the Hankel matrix for a given sequence length."""
    Z = get_hankel(seq_len, use_hankel_L, device=device, dtype=dtype)

    # torch.linalg.eigh does not support low-precision dtypes such as bfloat16,
    # so run the eigendecomposition in float32.
    Z_float32 = Z.to(torch.float32)
    sigma, phi = torch.linalg.eigh(Z_float32)

    # eigh returns eigenvalues in ascending order; keep the K largest eigenpairs.
    sigma_k, phi_k = sigma[-K:], phi[:, -K:]

    # Scale each filter by sigma^{1/4} while still in float32, then cast back to
    # the requested device and dtype in a single step.
    phi_k = phi_k * sigma_k**0.25
    return phi_k.to(device=device, dtype=dtype)
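
# A minimal usage sketch (assuming `get_hankel` returns a (seq_len, seq_len)
# Hankel matrix, which this module does not show):
#
#     filters = get_spectral_filters(seq_len=1024, K=24)
#     # -> filters.shape == (1024, 24), one column per spectral filter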


def compute_dimensions(n: int) -> tuple[int, int, int]:
    """Return (T_prime, sqrt_T_prime, k_max) for tensorizing filters of length n."""
    if n <= 2:
        raise ValueError("n must be greater than 2")

    # Round n - 2 up to the next perfect square so the length factors into
    # two equal halves of size sqrt_T_prime.
    T_prime = math.ceil(math.sqrt(n - 2)) ** 2 + 2
    sqrt_T_prime = math.ceil(math.sqrt(T_prime - 2))
    k_max = sqrt_T_prime
    return T_prime, sqrt_T_prime, k_max
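
# Worked example: for n = 8192, ceil(sqrt(8190)) = 91, so T_prime = 91**2 + 2 = 8283,
# sqrt_T_prime = 91, and k_max = 91.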


def get_tensorized_spectral_filters_explicit(n: int, k: int, device: torch.device) -> torch.Tensor:
    """Reference implementation: sum the Kronecker products of all pairs of scaled filters."""
    T_prime, sqrt_T_prime, k_max = compute_dimensions(n)
    k = min(k, k_max)

    Z = get_hankel(sqrt_T_prime).to(device)
    sigma, phi = torch.linalg.eigh(Z)
    sigma_k = sigma[-k:]
    phi_k = phi[:, -k:]

    result = torch.zeros(sqrt_T_prime * sqrt_T_prime, device=device)

    # Accumulate kron(phi_i, phi_j) over every pair of sigma^{1/4}-scaled filters.
    for i in range(k):
        for j in range(k):
            phi_i = phi_k[:, i] * (sigma_k[i] ** 0.25)
            phi_j = phi_k[:, j] * (sigma_k[j] ** 0.25)
            result += torch.kron(phi_i, phi_j)

    return result
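
# Because the Kronecker product is bilinear, the double loop above is equivalent
# to a single Kronecker product of summed, scaled filters, e.g. (a sketch):
#
#     scaled = phi_k * sigma_k**0.25        # scale all k filters at once
#     s = scaled.sum(dim=1)                 # sum the scaled filter columns
#     result = torch.kron(s, s)             # == the O(k^2) accumulation above
#
# The explicit O(k^2) loop is kept as a readable reference.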


def get_tensorized_spectral_filters(
    n: int = 8192,
    k: int = 24,
    use_hankel_L: bool = False,
    device: torch.device | None = None,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """
    Compute tensorized spectral filters for the given sequence length and filter count.

    Args:
        n: Sequence length.
        k: Number of filters.
        use_hankel_L: If True, tensorize Hankel_main with Hankel_L
            (Hankel_main ⊗ Hankel_L); otherwise tensorize Hankel_main with
            itself (Hankel_main ⊗ Hankel_main).
        device: Computation device.
        dtype: Computation dtype.
    """
    assert torch.cuda.is_available(), "CUDA is required."

    T_prime, sqrt_T_prime, k_max = compute_dimensions(n)
    k = min(k, k_max)

    # The eigendecomposition runs on the default device; the result is moved to
    # `device` only at the end.
    Z = get_hankel(sqrt_T_prime)
    sigma, phi = torch.linalg.eigh(Z)
    phi_i = phi[:, -k:] * sigma[-k:] ** 0.25

    if use_hankel_L:
        logger.info("Mixing Hankel_L with Hankel_main to generate tensorized filters.")
        Z_L = get_hankel(sqrt_T_prime, True)
        sigma_L, phi_L = torch.linalg.eigh(Z_L)
        phi_j = phi_L[:, -k:] * sigma_L[-k:] ** 0.25
    else:
        phi_j = phi_i

    # Tensorize the two filter banks into a single (sqrt_T_prime**2, k**2) bank.
    filters = torch.kron(phi_i, phi_j)
    return filters.to(device=device, dtype=dtype)
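
# A minimal usage sketch (requires a CUDA device because of the assert above):
#
#     filters = get_tensorized_spectral_filters(n=8192, k=24, device=torch.device("cuda"))
#     # compute_dimensions(8192) -> sqrt_T_prime = 91, k_max = 91, so k = 24 is kept
#     # and filters has shape (91 * 91, 24 * 24) == (8281, 576).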