import math
from typing import Optional

import numpy as np
import torch

from .utils import logger
from .utils import get_hankel


def _scaled_top_eigenvectors(
    sigma: torch.Tensor,
    phi: torch.Tensor,
    k: int,
) -> torch.Tensor:
    """Return the eigenvectors of the ``k`` largest eigenvalues, scaled by ``sigma ** 0.25``.

    Args:
        sigma: Eigenvalues in ascending order, as returned by ``torch.linalg.eigh``.
        phi: Matching eigenvectors stored as columns.
        k: Number of trailing (largest-eigenvalue) columns to keep. Values
            larger than the number of eigenvalues keep all of them; ``k <= 0``
            keeps none.

    Returns:
        The selected columns of ``phi``, each multiplied by the fourth root of
        its eigenvalue. A negative eigenvalue produces NaNs in its column.
    """
    # An explicit start index is used instead of ``[-k:]`` because ``[-0:]``
    # is ``[0:]`` and would select everything when k == 0.
    start = max(sigma.shape[-1] - k, 0)
    return phi[:, start:] * sigma[start:] ** 0.25


def get_spectral_filters(
    seq_len: int,
    K: int,
    use_hankel_L: bool = False,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Build the top-``K`` spectral filters from a Hankel matrix.

    The matrix returned by ``get_hankel`` is eigendecomposed and the
    eigenvectors belonging to the ``K`` largest eigenvalues are scaled by the
    fourth root of their eigenvalue.

    Args:
        seq_len: Size argument forwarded to ``get_hankel``.
        K: Number of filters to keep.
        use_hankel_L: Forwarded to ``get_hankel`` to pick the Hankel variant.
        device: Device passed to ``get_hankel`` and used for the result.
        dtype: dtype passed to ``get_hankel`` and used for the result.

    Returns:
        The scaled eigenvectors as columns, on ``device`` with ``dtype``
        (presumably shape ``(seq_len, K)`` -- depends on ``get_hankel``;
        TODO confirm).
    """
    Z = get_hankel(seq_len, use_hankel_L, device=device, dtype=dtype)

    # The decomposition is run in float32 (the default dtype is bfloat16).
    sigma, phi = torch.linalg.eigh(Z.to(torch.float32))

    # Scale while still in float32 and cast once at the end, so the low
    # precision target dtype only rounds the final values.
    filters = _scaled_top_eigenvectors(sigma, phi, K)
    return filters.to(device=device, dtype=dtype)


def compute_dimensions(n: int) -> tuple[int, int, int]:
    """Derive the padded length and per-axis size used by the tensorized filters.

    Args:
        n: Sequence length; must be greater than 2.

    Returns:
        ``(T_prime, sqrt_T_prime, k_max)`` where
        ``sqrt_T_prime = ceil(sqrt(n - 2))``,
        ``T_prime = sqrt_T_prime ** 2 + 2`` and ``k_max = sqrt_T_prime``.

    Raises:
        ValueError: If ``n <= 2``.
    """
    if n <= 2:
        raise ValueError("n must be greater than 2")

    # T_prime - 2 is a perfect square by construction, so its root is simply
    # the value it was built from; no need to take a second square root.
    sqrt_T_prime = math.ceil(math.sqrt(n - 2))
    T_prime = sqrt_T_prime ** 2 + 2
    k_max = sqrt_T_prime
    return T_prime, sqrt_T_prime, k_max


def get_tensorized_spectral_filters_explicit(n: int, k: int, device: torch.device) -> torch.Tensor:
    """Sum the Kronecker products of every pair of scaled top-``k`` eigenvectors.

    Args:
        n: Sequence length; must be greater than 2.
        k: Number of eigenvectors to use; capped at ``k_max`` from
            ``compute_dimensions``.
        device: Device on which the computation runs and the result lives.

    Returns:
        A 1-D tensor of length ``sqrt_T_prime ** 2`` holding
        ``sum_{i, j} kron(phi_i, phi_j)`` over the scaled eigenvectors.

    Raises:
        ValueError: If ``n <= 2`` (raised by ``compute_dimensions``).
    """
    _, sqrt_T_prime, k_max = compute_dimensions(n)
    k = min(k, k_max)

    Z = get_hankel(sqrt_T_prime).to(device)
    sigma, phi = torch.linalg.eigh(Z)

    # Scale every selected column once, rather than inside the O(k^2) loop.
    scaled = _scaled_top_eigenvectors(sigma, phi, k)

    result = torch.zeros(sqrt_T_prime * sqrt_T_prime, device=device)
    for i in range(k):
        for j in range(k):
            result += torch.kron(scaled[:, i], scaled[:, j])

    return result


def get_tensorized_spectral_filters(
    n: int = 8192,
    k: int = 24,
    use_hankel_L: bool = False,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """
    Compute tensorized spectral filters for given sequence length and filter
    count.

    The filters are the Kronecker product of two sets of scaled top-``k``
    eigenvectors, each taken from a Hankel matrix of size ``sqrt_T_prime``
    (see ``compute_dimensions``).

    Args:
        n: Sequence length; must be greater than 2.
        k: Number of filters per factor; capped at ``k_max`` from
            ``compute_dimensions``.
        use_hankel_L: Hankel_main ⊗ Hankel_L? Default is Hankel_main ⊗ Hankel_main.
        device: Device the returned tensor is moved to. The decomposition
            itself runs wherever ``get_hankel`` places its output.
        dtype: dtype of the returned tensor. The decomposition and scaling run
            in the dtype ``get_hankel`` returns.

    Returns:
        ``kron(phi_i, phi_j)`` on ``device`` with ``dtype`` (presumably shape
        ``(sqrt_T_prime ** 2, k ** 2)`` -- depends on ``get_hankel``; TODO confirm).

    Raises:
        AssertionError: If CUDA is not available.
        ValueError: If ``n <= 2`` (raised by ``compute_dimensions``).
    """
    # NOTE(review): nothing below visibly needs CUDA unless `device` is a CUDA
    # device; the check is kept so existing callers see the same failure mode.
    assert torch.cuda.is_available(), "CUDA is required."

    _, sqrt_T_prime, k_max = compute_dimensions(n)
    k = min(k, k_max)

    Z = get_hankel(sqrt_T_prime)
    sigma, phi = torch.linalg.eigh(Z)
    phi_i = _scaled_top_eigenvectors(sigma, phi, k)

    if use_hankel_L:
        # TODO: We may want to use Hankel_L above too if use_hankel_L is true,
        # make another variable for this (mix != use_hankel_L)
        logger.info("Mixing Hankel_L with Hankel_main to generate tensorized filters.")
        Z_L = get_hankel(sqrt_T_prime, True)
        sigma_L, phi_L = torch.linalg.eigh(Z_L)
        phi_j = _scaled_top_eigenvectors(sigma_L, phi_L, k)
    else:
        phi_j = phi_i

    filters = torch.kron(phi_i, phi_j)
    return filters.to(device=device, dtype=dtype)