# NOTE: file-viewer page residue captured with this file (not code): ljy266987 / add lfs / 12bfd03 / raw / history blame / 2.09 kB
import math
import random
import numpy as np
import torch.nn as nn
from academicodec.modules.seanet import SEANetDecoder
from academicodec.modules.seanet import SEANetEncoder
from academicodec.quantization import ResidualVectorQuantizer
# Generator
class SoundStream(nn.Module):
    """Encoder -> residual vector quantizer -> decoder generator.

    The input is passed through a ``SEANetEncoder``, quantized by a
    ``ResidualVectorQuantizer`` at a chosen bandwidth, and reconstructed by a
    ``SEANetDecoder``.
    """

    def __init__(self,
                 n_filters,
                 D,
                 target_bandwidths=None,
                 ratios=None,
                 sample_rate=24000,
                 bins=1024,
                 normalize=False):
        """Build the encoder, quantizer and decoder.

        Args:
            n_filters: Forwarded as ``n_filters`` to the encoder and decoder.
            D: Forwarded as ``dimension`` to the encoder, quantizer and
                decoder.
            target_bandwidths: Bandwidths the model can be run at. The last
                entry sizes the quantizer and is the default for ``encode``.
                Defaults to ``[7.5, 15]`` (presumably kbps, given the
                ``1000 *`` factor below -- confirm against the quantizer).
            ratios: Forwarded as ``ratios`` to the encoder and decoder; their
                product is stored as ``hop_length``. Defaults to
                ``[8, 5, 4, 2]``.
            sample_rate: Sample rate used to derive ``frame_rate``.
            bins: Forwarded as ``bins`` to the quantizer.
            normalize: Accepted for interface compatibility; not used here.
        """
        super().__init__()
        # Fresh lists per call: a list literal in the signature would be a
        # single object shared by every instance built with the defaults
        # (and ``target_bandwidths`` is stored on ``self`` below).
        if target_bandwidths is None:
            target_bandwidths = [7.5, 15]
        if ratios is None:
            ratios = [8, 5, 4, 2]

        # Product of all ratios.
        self.hop_length = np.prod(ratios)
        # 75 with the default sample_rate and ratios (24000 / 320).
        self.frame_rate = math.ceil(sample_rate / self.hop_length)
        self.bits_per_codebook = int(math.log2(bins))
        self.target_bandwidths = target_bandwidths

        # Number of quantizers needed to reach the last target bandwidth.
        # NOTE(review): the constant 10 equals log2(1024), i.e. the default
        # ``bins``; it does not follow ``bins`` when another value is passed.
        # Left unchanged so ``n_q`` stays the same for existing configs.
        n_q = int(1000 * target_bandwidths[-1] // (self.frame_rate * 10))

        self.encoder = SEANetEncoder(
            n_filters=n_filters, dimension=D, ratios=ratios)
        self.quantizer = ResidualVectorQuantizer(
            dimension=D, n_q=n_q, bins=bins)
        self.decoder = SEANetDecoder(
            n_filters=n_filters, dimension=D, ratios=ratios)

    def get_last_layer(self):
        """Return the ``weight`` of the last entry in ``decoder.layers``."""
        return self.decoder.layers[-1].weight

    def forward(self, x):
        """Encode, quantize at a randomly chosen bandwidth, and decode.

        Args:
            x: Input passed to the encoder.

        Returns:
            A 3-tuple ``(output, commit_loss, None)``. The codes and the
            bandwidth value returned by the quantizer are discarded.
        """
        e = self.encoder(x)
        # Pick one of the configured bandwidths uniformly at random.
        bw = random.choice(self.target_bandwidths)
        quantized, codes, bandwidth, commit_loss = self.quantizer(
            e, self.frame_rate, bw)
        o = self.decoder(quantized)
        return o, commit_loss, None

    def encode(self, x, target_bw=None, st=None):
        """Encode ``x`` and return the quantizer codes.

        Args:
            x: Input passed to the encoder.
            target_bw: Bandwidth to quantize at; defaults to the last entry
                of ``target_bandwidths``.
            st: Forwarded as the fourth argument of ``quantizer.encode``;
                defaults to 0.

        Returns:
            Whatever ``quantizer.encode`` returns.
        """
        e = self.encoder(x)
        bw = self.target_bandwidths[-1] if target_bw is None else target_bw
        if st is None:
            st = 0
        return self.quantizer.encode(e, self.frame_rate, bw, st)

    def decode(self, codes):
        """Decode quantizer ``codes`` back through the decoder."""
        quantized = self.quantizer.decode(codes)
        return self.decoder(quantized)