Spaces:
Build error
Build error
File size: 2,027 Bytes
c983126 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 |
import torch
class FIRFilter(torch.nn.Module):
    """Time-domain FIR filter whose taps are produced by a learned adaptor.

    Args:
        num_control_params (int): Number of filter taps (ntaps). This is also
            the input and output width of the linear adaptor applied to ``b``.
    """

    def __init__(self, num_control_params=63):
        super().__init__()
        self.num_control_params = num_control_params
        self.adaptor = torch.nn.Linear(num_control_params, num_control_params)

    def forward(self, x, b, **kwargs):
        """Forward pass by applying an FIR filter to each batch element.

        Every channel of batch element ``i`` is filtered with the taps
        ``adaptor(b)[i]``. The input is zero-padded so the output has the same
        number of samples as the input.

        NOTE: the filter is applied with ``conv1d`` (cross-correlation), so the
        taps are not flipped as in a textbook convolution.

        Args:
            x (tensor): Input signals with shape (batch x channels x samples).
            b (tensor): Matrix of FIR filter coefficients with shape
                (batch x ntaps), where ntaps == num_control_params.
            **kwargs: Ignored.

        Returns:
            tensor: Filtered signals with shape (batch x channels x samples).
        """
        bs, ch, s = x.size()
        b = self.adaptor(b).reshape(bs, -1)
        ntaps = b.shape[-1]
        # Pad ntaps - 1 samples in total so the output length equals s for both
        # odd and even tap counts. For odd ntaps this is ntaps // 2 on each side.
        x = torch.nn.functional.pad(x, ((ntaps - 1) // 2, ntaps // 2))
        # Fold (batch, channel) into the channel dim and use a grouped conv.
        # Each row is then filtered by its own kernel, with no vmap or Python loop.
        x = x.reshape(1, bs * ch, -1)
        # Repeat each batch element's taps once per channel, in batch-major order
        # to match the reshape above -> weight shape (bs * ch, 1, ntaps).
        w = b.repeat_interleave(ch, dim=0).unsqueeze(1)
        y = self.lfilter(x, w, groups=bs * ch)
        return y.view(bs, ch, s)

    @staticmethod
    def lfilter(x, b, groups=1):
        """Apply ``conv1d`` to ``x`` with weight ``b`` (no padding).

        Args:
            x (tensor): Input of shape (N x C_in x L).
            b (tensor): Weight of shape (C_out x C_in / groups x ntaps).
            groups (int): Number of blocked connections, passed to ``conv1d``.

        Returns:
            tensor: Output of shape (N x C_out x L - ntaps + 1).
        """
        return torch.nn.functional.conv1d(x, b, groups=groups)
class FrequencyDomainFIRFilter(torch.nn.Module):
    """FIR filter applied as a multiplication in the frequency domain.

    Args:
        num_control_params (int): Number of filter taps (ntaps). This is also
            the input and output width of the linear adaptor applied to ``b``.
    """

    def __init__(self, num_control_params=31):
        super().__init__()
        self.num_control_params = num_control_params
        self.adaptor = torch.nn.Linear(num_control_params, num_control_params)

    def forward(self, x, b, **kwargs):
        """Forward pass by applying an FIR filter to each batch element.

        Computes the causal linear convolution
        ``y[n] = sum_k h[k] * x[n - k]`` for ``n = 0 .. samples - 1``, where
        ``h = adaptor(b)[i]``. It is evaluated via FFT and applied to every
        channel of batch element ``i``.

        Args:
            x (tensor): Input signals with shape (batch x channels x samples).
            b (tensor): Matrix of FIR filter coefficients with shape
                (batch x ntaps), where ntaps == num_control_params.
            **kwargs: Ignored.

        Returns:
            tensor: Real filtered signals with shape
            (batch x channels x samples).
        """
        bs, c, s = x.size()
        # The extra singleton dim lets the filter spectrum broadcast over channels.
        h = self.adaptor(b).reshape(bs, 1, -1)
        ntaps = h.shape[-1]
        # Zero-pad both signals to the full linear-convolution length. The two
        # spectra then have the same number of bins, and their product does not
        # wrap around (no circular aliasing).
        n_fft = s + ntaps - 1
        # Transform input to the frequency domain.
        X = torch.fft.rfft(x, n=n_fft)
        # Frequency response of the filter.
        H = torch.fft.rfft(h, n=n_fft)
        # Filtering is a multiplication in the frequency domain.
        Y = X * H
        # irfft inverts the one-sided spectrum into a real signal of n_fft
        # samples. Keep the first s samples: the causal filter output.
        y = torch.fft.irfft(Y, n=n_fft)[..., :s]
        return y
|