import torch
import torch.nn as nn

from .convolve import convolve, flash_convolve

try:
    from flashfftconv import FlashFFTConv

    flash_fft_available = True
except ImportError as e:
    print(
        f"Unable to import FlashFFTConv: {e}. Falling back to PyTorch implementation."
    )
    flash_fft_available = False


class STU(nn.Module):
    """Spectral Transform Unit.

    Convolves inputs against a fixed bank of spectral filters ``phi`` and
    projects the result with learned parameter matrices. Two parameterizations
    are supported, chosen by ``config.use_approx``:

    * approx:  project inputs (``M_inputs``) and filters (``M_filters``)
      first, then convolve the projections.
    * exact:   convolve first, then contract the result against
      ``M_phi_plus`` (and ``M_phi_minus`` unless ``use_hankel_L``).

    Args:
        config: configuration object; fields read here are ``torch_dtype``,
            ``device``, ``num_eigh``, ``n_embd``, ``use_hankel_L``,
            ``use_approx`` and ``use_flash_fft``.
        phi: spectral filter bank tensor; moved to ``config.device`` and cast
            to the configured dtype.
        n: sequence/FFT length passed to the convolution helpers.
    """

    def __init__(self, config, phi, n) -> None:
        super(STU, self).__init__()
        self.config = config
        # config.torch_dtype may arrive as a string name (e.g. "float16")
        # or as an actual torch.dtype; normalize to a torch.dtype.
        if isinstance(config.torch_dtype, str):
            torch_dtype = getattr(torch, config.torch_dtype)
        else:
            torch_dtype = config.torch_dtype
        self.phi = phi.to(device=config.device, dtype=torch_dtype)
        self.n = n
        self.K = config.num_eigh
        self.d_in = config.n_embd
        self.d_out = config.n_embd
        self.use_hankel_L = config.use_hankel_L
        self.use_approx = config.use_approx

        self.flash_fft = None
        if config.use_flash_fft and flash_fft_available:
            if torch_dtype == torch.float16:  # Only enable for float16
                self.flash_fft = FlashFFTConv(self.n, dtype=torch.float16)
            else:
                print(f"Disabling FlashFFTConv for unsupported dtype: {torch_dtype}")

        if self.use_approx:
            # Approximate parameterization: separate input and filter
            # projections applied before the convolution.
            self.M_inputs = nn.Parameter(
                torch.empty(self.d_in, self.d_out, dtype=torch_dtype)
            )
            self.M_filters = nn.Parameter(
                torch.empty(self.K, self.d_in, dtype=torch_dtype)
            )
        else:
            # Exact parameterization: full (K, d_in, d_out) contraction
            # tensors applied after the convolution.
            self.M_phi_plus = nn.Parameter(
                torch.empty(self.K, self.d_in, self.d_out, dtype=torch_dtype)
            )
            if not self.use_hankel_L:
                self.M_phi_minus = nn.Parameter(
                    torch.empty(self.K, self.d_in, self.d_out, dtype=torch_dtype)
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the spectral transform to ``x`` and return the result.

        NOTE(review): assumes ``x`` is (batch, seq, d_in) so that the
        tensordot over dims ([2, 3], [0, 1]) lines up — confirm against
        the convolve helpers.
        """
        # BUG FIX: self.M_inputs only exists when use_approx is True.
        # The original read self.M_inputs.dtype unconditionally, raising
        # AttributeError on every forward pass in the non-approx
        # configuration. Take the dtype from a parameter that is
        # guaranteed to exist on the active path.
        dtype = self.M_inputs.dtype if self.use_approx else self.M_phi_plus.dtype
        x = x.to(dtype=dtype)

        if self.use_approx:
            # Contract inputs and filters over the K and d_in dimensions,
            # then convolve the projected tensors.
            x_proj = x @ self.M_inputs
            phi_proj = self.phi @ self.M_filters
            x_proj = x_proj.to(dtype=dtype)
            phi_proj = phi_proj.to(dtype=dtype)
            if self.flash_fft:
                spectral_plus, spectral_minus = flash_convolve(
                    x_proj, phi_proj, self.flash_fft, self.use_approx
                )
            else:
                spectral_plus, spectral_minus = convolve(
                    x_proj, phi_proj, self.n, self.use_approx
                )
        else:
            # Convolve inputs and filters first...
            if self.flash_fft:
                U_plus, U_minus = flash_convolve(
                    x, self.phi, self.flash_fft, self.use_approx
                )
            else:
                U_plus, U_minus = convolve(x, self.phi, self.n, self.use_approx)
            # ...then contract over the K and d_in dimensions.
            spectral_plus = torch.tensordot(
                U_plus, self.M_phi_plus, dims=([2, 3], [0, 1])
            )
            if not self.use_hankel_L:
                spectral_minus = torch.tensordot(
                    U_minus, self.M_phi_minus, dims=([2, 3], [0, 1])
                )

        # With use_hankel_L the negative branch is dropped entirely
        # (M_phi_minus is never even created in __init__).
        return spectral_plus if self.use_hankel_L else spectral_plus + spectral_minus