Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
from audiotools import AudioSignal | |
from audiotools.ml import BaseModel | |
from encodec import EncodecModel | |
class Encodec(BaseModel): | |
def __init__(self, sample_rate: int = 24000, bandwidth: float = 24.0): | |
super().__init__() | |
if sample_rate == 24000: | |
self.model = EncodecModel.encodec_model_24khz() | |
else: | |
self.model = EncodecModel.encodec_model_48khz() | |
self.model.set_target_bandwidth(bandwidth) | |
self.sample_rate = 44100 | |
def forward( | |
self, | |
audio_data: torch.Tensor, | |
sample_rate: int = 44100, | |
n_quantizers: int = None, | |
): | |
signal = AudioSignal(audio_data, sample_rate) | |
signal.resample(self.model.sample_rate) | |
recons = self.model(signal.audio_data) | |
recons = AudioSignal(recons, self.model.sample_rate) | |
recons.resample(sample_rate) | |
return {"audio": recons.audio_data} | |
if __name__ == "__main__": | |
import numpy as np | |
from functools import partial | |
model = Encodec() | |
for n, m in model.named_modules(): | |
o = m.extra_repr() | |
p = sum([np.prod(p.size()) for p in m.parameters()]) | |
fn = lambda o, p: o + f" {p/1e6:<.3f}M params." | |
setattr(m, "extra_repr", partial(fn, o=o, p=p)) | |
print(model) | |
print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()])) | |
length = 88200 * 2 | |
x = torch.randn(1, 1, length).to(model.device) | |
x.requires_grad_(True) | |
x.retain_grad() | |
# Make a forward pass | |
out = model(x)["audio"] | |
print(x.shape, out.shape) | |