# Scraped page banner (not part of the module): "Spaces: Runtime error".
import math | |
import random | |
import numpy as np | |
import torch.nn as nn | |
from academicodec.modules.seanet import SEANetDecoder | |
from academicodec.modules.seanet import SEANetEncoder | |
from academicodec.quantization import ResidualVectorQuantizer | |
# Generator
class SoundStream(nn.Module):
    """Generator that chains a SEANet encoder, a residual vector quantizer
    and a SEANet decoder.

    Args:
        n_filters: Forwarded as ``n_filters`` to both ``SEANetEncoder`` and
            ``SEANetDecoder``.
        D: Forwarded as ``dimension`` to the encoder, the quantizer and the
            decoder.
        target_bandwidths: Candidate bandwidths. The last entry sizes the
            quantizer (``n_q``) and is the default used by :meth:`encode`;
            :meth:`forward` picks one entry at random on every call.
            (Presumably expressed in kbps, given the ``1000 *`` factor in
            the ``n_q`` formula -- TODO confirm.)
        ratios: Forwarded as ``ratios`` to the encoder and decoder; their
            product is stored as ``hop_length``.
        sample_rate: Combined with ``hop_length`` to derive ``frame_rate``.
        bins: Forwarded as ``bins`` to the quantizer; ``log2(bins)`` is
            stored as ``bits_per_codebook``.
        normalize: Accepted for interface compatibility; not used by this
            class.
    """

    def __init__(self,
                 n_filters,
                 D,
                 target_bandwidths=(7.5, 15),
                 ratios=(8, 5, 4, 2),
                 sample_rate=24000,
                 bins=1024,
                 normalize=False):
        super().__init__()
        # Defaults are immutable tuples so they cannot be shared and mutated
        # across instances; copy into fresh lists so collaborators keep
        # receiving lists and later caller-side mutation cannot leak in.
        ratios = list(ratios)
        target_bandwidths = list(target_bandwidths)

        # Product of all ratios.
        self.hop_length = np.prod(ratios)
        # With the defaults: ceil(24000 / 320) == 75.
        self.frame_rate = math.ceil(sample_rate / self.hop_length)
        self.bits_per_codebook = int(math.log2(bins))
        self.target_bandwidths = target_bandwidths

        # NOTE(review): the literal 10 equals log2(bins) only for the default
        # bins=1024. It is kept as-is so the number of codebooks (n_q) does
        # not change for existing configurations.
        n_q = int(1000 * target_bandwidths[-1] // (self.frame_rate * 10))

        self.encoder = SEANetEncoder(
            n_filters=n_filters, dimension=D, ratios=ratios)
        self.quantizer = ResidualVectorQuantizer(
            dimension=D, n_q=n_q, bins=bins)
        self.decoder = SEANetDecoder(
            n_filters=n_filters, dimension=D, ratios=ratios)

    def get_last_layer(self):
        """Return the ``weight`` attribute of the decoder's last layer.

        Assumes the decoder exposes an indexable ``layers`` attribute --
        TODO confirm against ``SEANetDecoder``.
        """
        return self.decoder.layers[-1].weight

    def forward(self, x):
        """Encode ``x``, quantize at a randomly chosen bandwidth, and decode.

        Returns:
            A 3-tuple ``(o, commit_loss, None)``: the decoder output, the
            commitment loss reported by the quantizer, and a ``None``
            placeholder.
        """
        e = self.encoder(x)
        # One uniformly drawn entry of target_bandwidths per call.
        bw = random.choice(self.target_bandwidths)
        # The quantizer also returns codes and a bandwidth value; neither is
        # used here.
        quantized, _, _, commit_loss = self.quantizer(e, self.frame_rate, bw)
        o = self.decoder(quantized)
        return o, commit_loss, None

    def encode(self, x, target_bw=None, st=None):
        """Encode ``x`` and return the quantizer codes.

        Args:
            x: Input passed to the encoder.
            target_bw: Bandwidth handed to the quantizer; defaults to the
                last entry of ``target_bandwidths`` when ``None``.
            st: Fourth argument of ``quantizer.encode``; defaults to ``0``
                when ``None``. (Presumably a starting quantizer index --
                TODO confirm.)
        """
        e = self.encoder(x)
        bw = self.target_bandwidths[-1] if target_bw is None else target_bw
        st = 0 if st is None else st
        return self.quantizer.encode(e, self.frame_rate, bw, st)

    def decode(self, codes):
        """Turn ``codes`` back into an output via the quantizer and decoder."""
        quantized = self.quantizer.decode(codes)
        return self.decoder(quantized)