Spaces:
Build error
Build error
import torch | |
class FIRFilter(torch.nn.Module):
    """Time-domain FIR filter whose taps are predicted per batch element.

    The raw control vector ``b`` is first mapped through a learned linear
    ``adaptor`` and the result is used as the filter kernel for the
    corresponding batch item.

    Args:
        num_control_params (int): Number of control parameters, which is
            also the number of filter taps. Default: 63.
    """

    def __init__(self, num_control_params=63):
        super().__init__()
        self.num_control_params = num_control_params
        self.adaptor = torch.nn.Linear(num_control_params, num_control_params)

    def forward(self, x, b, **kwargs):
        """Apply a separate FIR kernel to each batch element.

        The kernel is applied with ``conv1d`` (i.e. as a cross-correlation,
        the kernel is not flipped) on a zero-padded input so that the output
        has the same number of samples as the input.

        Args:
            x (tensor): Input signals with shape (batch x channels x samples).
            b (tensor): Control parameters with shape (batch x ntaps).
            **kwargs: Ignored; accepted for interface compatibility.

        Returns:
            tensor: Filtered signals with shape (batch x channels x samples).
        """
        bs, ch, s = x.size()
        b = self.adaptor(b)
        ntaps = b.shape[-1]

        # Pad ntaps - 1 samples in total so the valid convolution yields
        # exactly `s` outputs. For odd ntaps this is ntaps // 2 on each side;
        # for even ntaps the extra sample goes on the left.
        pad_left = ntaps // 2
        pad_right = ntaps - 1 - pad_left
        x = torch.nn.functional.pad(x, (pad_left, pad_right))

        y = self.batched_lfilter(x, b)
        return y.reshape(bs, ch, s)

    @staticmethod
    def batched_lfilter(x, b):
        """Filter every batch element with its own kernel in one conv call.

        Batch and channel axes are folded into the channel axis and a grouped
        convolution (one group per folded channel) is used, so no Python loop
        or ``vmap`` is required.

        Args:
            x (tensor): Padded signals with shape (batch x channels x samples).
            b (tensor): Kernels with ``batch * ntaps`` elements, ntaps last.

        Returns:
            tensor: Shape (batch x channels x samples - ntaps + 1).
        """
        bs, ch, n = x.size()
        ntaps = b.shape[-1]
        # One kernel per (batch, channel) pair; channels share the batch kernel.
        weight = (
            b.reshape(bs, 1, 1, ntaps)
            .expand(bs, ch, 1, ntaps)
            .reshape(bs * ch, 1, ntaps)
        )
        y = torch.nn.functional.conv1d(
            x.reshape(1, bs * ch, n), weight, groups=bs * ch
        )
        return y.reshape(bs, ch, -1)

    @staticmethod
    def lfilter(x, b):
        """Thin wrapper around ``torch.nn.functional.conv1d(x, b)``."""
        return torch.nn.functional.conv1d(x, b)
class FrequencyDomainFIRFilter(torch.nn.Module):
    """FIR filter applied by multiplication in the frequency domain.

    The raw control vector ``b`` is first mapped through a learned linear
    ``adaptor`` and the result is used as the filter taps for the
    corresponding batch item.

    Args:
        num_control_params (int): Number of control parameters, which is
            also the number of filter taps. Default: 31.
    """

    def __init__(self, num_control_params=31):
        super().__init__()
        self.num_control_params = num_control_params
        self.adaptor = torch.nn.Linear(num_control_params, num_control_params)

    def forward(self, x, b, **kwargs):
        """Apply a separate FIR filter to each batch element via the FFT.

        Computes the causal linear convolution
        ``y[n] = sum_k b[k] * x[n - k]`` and keeps the first ``samples``
        outputs, so the result has the same length as the input.

        Args:
            x (tensor): Input signals with shape (batch x channels x samples).
            b (tensor): Control parameters with shape (batch x ntaps).
            **kwargs: Ignored; accepted for interface compatibility.

        Returns:
            tensor: Real-valued filtered signals with shape
            (batch x channels x samples).
        """
        bs, c, s = x.size()
        b = self.adaptor(b)
        ntaps = b.shape[-1]

        # Zero-pad both transforms to the full linear-convolution length so
        # the spectra have matching sizes and no circular wrap-around occurs.
        n_fft = s + ntaps - 1

        # transform input to freq. domain (along the sample axis)
        X = torch.fft.rfft(x, n=n_fft, dim=-1)
        # frequency response of filter, broadcast over the channel axis
        H = torch.fft.rfft(b.reshape(bs, 1, ntaps), n=n_fft, dim=-1)
        # apply filter as multiplication in freq. domain
        Y = X * H
        # transform back to time domain; irfft inverts the one-sided spectrum
        # to a real signal, then the tail beyond the input length is dropped
        y = torch.fft.irfft(Y, n=n_fft, dim=-1)[..., :s]
        return y