ljy266987
add lfs
12bfd03
raw
history blame
6.47 kB
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Command-line for audio compression."""
import argparse
import os
import sys
import typing as tp
from collections import OrderedDict
from pathlib import Path
import librosa
import soundfile as sf
import torch
from academicodec.models.encodec.net3 import SoundStream
def save_audio(wav: torch.Tensor,
               path: tp.Union[Path, str],
               sample_rate: int,
               rescale: bool = False):
    """Write ``wav`` to ``path`` with every sample kept within +/-0.99.

    Args:
        wav: Audio tensor; singleton dimensions are squeezed away and the data
            is moved to CPU before writing.
        path: Destination path handed to ``soundfile.write``.
        sample_rate: Sample rate passed to ``soundfile.write``.
        rescale: If True, multiply the whole signal by ``min(0.99 / peak, 1)``
            (it is only ever scaled down). If False, hard-clamp each sample to
            the range [-0.99, 0.99].
    """
    limit = 0.99
    if not rescale:
        wav = wav.clamp(-limit, limit)
    else:
        peak = wav.abs().max()
        wav = wav * min(limit / peak, 1)
    samples = wav.squeeze().cpu().numpy()
    sf.write(path, samples, sample_rate)
def get_parser():
    """Build the command-line parser for batch SoundStream round-trips.

    Every file in the ``--input`` directory is encoded and decoded with the
    checkpoint at ``--resume_path``. Each reconstruction is written under the
    same file name into the ``--output`` directory.

    Both ``--input`` and ``--output`` are required. Without them, the batch
    driver would fail later with an opaque ``None``-related error.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(
        'encodec',
        description='High fidelity neural audio codec. '
        'Compresses and then decompresses every audio file in the --input '
        'directory and writes the reconstructions to the --output directory.')
    parser.add_argument(
        '--input',
        type=Path,
        required=True,
        help='Input directory; every file in it is loaded with librosa.')
    parser.add_argument(
        '--output',
        type=Path,
        required=True,
        help='Output directory (created if missing); reconstructed files '
        'keep their input file names.')
    parser.add_argument(
        '--resume_path', type=str, default='resume_path', help='resume_path')
    parser.add_argument(
        '--sr', type=int, default=16000, help='sample rate of model')
    parser.add_argument(
        '-r',
        '--rescale',
        action='store_true',
        help='Automatically rescale the output to avoid clipping.')
    parser.add_argument(
        '--ratios',
        type=int,
        nargs='+',
        # prod(ratios) == hop_size (default 8*5*4*2 = 320)
        default=[8, 5, 4, 2],
        help='ratios of SoundStream; their product is the hop size, so set '
        'them for the desired hop_size (e.g. 320d, 320, 240d, ...)')
    parser.add_argument(
        '--target_bandwidths',
        type=float,
        nargs='+',
        # default for 16k_320d
        default=[1, 1.5, 2, 4, 6, 12],
        help='target_bandwidths of net3.py')
    parser.add_argument(
        '--target_bw',
        type=float,
        # default for 16k_320d
        default=12,
        help='target_bw of net3.py')
    return parser
def fatal(*args):
    """Print ``args`` to stderr, space-separated like ``print``, then exit with status 1."""
    message = " ".join(str(part) for part in args)
    sys.stderr.write(message + "\n")
    raise SystemExit(1)
# NOTE: check_clipping below only prints a warning; it does not actually clip anything.
def check_clipping(wav, rescale):
    """Warn on stderr when ``wav`` peaks above 0.99 and rescaling is off.

    The check is skipped entirely when ``rescale`` is truthy. ``wav`` is never
    modified; this function only reports.
    """
    if not rescale:
        mx = wav.abs().max()
        limit = 0.99
        if mx > limit:
            print(
                f"Clipping!! max scale {mx}, limit is {limit}. "
                "To avoid clipping, use the `-r` option to rescale the output.",
                file=sys.stderr)
def test_one(args, wav_root, store_root, rescale, soundstream):
    """Round-trip one audio file through ``soundstream`` and save the result.

    Loads ``wav_root`` with librosa at ``args.sr`` Hz, encodes it at
    ``args.target_bw`` and decodes it. It then warns if the reconstruction
    exceeds the clipping limit and writes the result to ``store_root``.

    Args:
        args: Parsed CLI namespace; ``sr`` and ``target_bw`` are read.
        wav_root: Path of the input audio file.
        store_root: Path the reconstructed audio is written to.
        rescale: Forwarded to ``check_clipping`` and ``save_audio``.
        soundstream: Model with ``parameters()``, ``encode(wav, target_bw=...)``
            and ``decode(compressed)``. The input is placed on the same device
            as its first parameter.
    """
    # librosa.load resamples to args.sr. With its default mono=True it returns
    # a 1-D array of shape (T,).
    wav, _ = librosa.load(wav_root, sr=args.sr)
    # (T,) -> (1, 1, T): add the leading axes the model is fed with. Follow the
    # model's device instead of hard-coding CUDA so CPU models also work.
    device = next(soundstream.parameters()).device
    wav = torch.tensor(wav).unsqueeze(0).unsqueeze(1).to(device)
    # Inference only: disable autograd so no graph is built and kept alive.
    with torch.no_grad():
        compressed = soundstream.encode(wav, target_bw=args.target_bw)
        print('finish compressing')
        out = soundstream.decode(compressed)
    out = out.detach().cpu().squeeze(0)
    check_clipping(out, rescale)
    save_audio(wav=out, path=store_root, sample_rate=args.sr, rescale=rescale)
    print('finish decompressing')
def remove_encodec_weight_norm(model):
    """Remove weight normalization in place from ``model``'s SEANet encoder and decoder.

    Only the direct children of ``model.encoder.model`` and
    ``model.decoder.model`` are visited:

    * ``SEANetResnetBlock``: the shortcut conv plus every ``SConv1d`` directly
      inside its ``block``.
    * ``SConv1d``: its inner conv.
    * ``SConvTranspose1d`` (decoder only): its inner transposed conv.
    """
    from academicodec.modules import SConv1d
    from academicodec.modules.seanet import SConvTranspose1d
    from academicodec.modules.seanet import SEANetResnetBlock
    from torch.nn.utils import remove_weight_norm

    def _strip(container, with_transposed):
        # Walk direct children in registration order.
        for layer in container._modules.values():
            if isinstance(layer, SEANetResnetBlock):
                remove_weight_norm(layer.shortcut.conv.conv)
                for inner in layer.block._modules.values():
                    if isinstance(inner, SConv1d):
                        remove_weight_norm(inner.conv.conv)
            elif with_transposed and isinstance(layer, SConvTranspose1d):
                remove_weight_norm(layer.convtr.convtr)
            elif isinstance(layer, SConv1d):
                remove_weight_norm(layer.conv.conv)

    _strip(model.encoder.model, with_transposed=False)
    _strip(model.decoder.model, with_transposed=True)
def test_batch():
    """Round-trip every file in ``--input`` through a SoundStream checkpoint.

    The function works through these steps:

    * Parses the command line with ``get_parser``.
    * Builds a ``SoundStream`` from the CLI hyper-parameters.
    * Loads the weights from ``--resume_path``.
    * Strips weight norm and moves the model to CUDA in eval mode.
    * Writes one reconstruction per regular file of ``--input`` into
      ``--output``, which is created if missing, under the same file name.

    It exits with status 1 (via ``fatal``) on a missing or invalid input or
    output argument.
    """
    args = get_parser().parse_args()
    print("args.target_bandwidths:", args.target_bandwidths)
    if args.input is None:
        fatal("No input directory given; pass --input.")
    if not args.input.exists():
        fatal(f"Input file {args.input} does not exist.")
    if not args.input.is_dir():
        fatal(f"Input {args.input} is not a directory.")
    if args.output is None:
        fatal("No output directory given; pass --output.")
    # Only regular files can be decoded; sort for a deterministic order.
    input_lists = sorted(
        name for name in os.listdir(args.input)
        if os.path.isfile(os.path.join(args.input, name)))
    soundstream = SoundStream(
        n_filters=32,
        D=512,
        ratios=args.ratios,
        sample_rate=args.sr,
        target_bandwidths=args.target_bandwidths)
    # Load onto CPU first; the model is moved to CUDA after the weights are in.
    parameter_dict = torch.load(args.resume_path, map_location='cpu')
    soundstream.load_state_dict(_strip_module_prefix(parameter_dict))
    remove_encodec_weight_norm(soundstream)
    soundstream.cuda()
    soundstream.eval()
    os.makedirs(args.output, exist_ok=True)
    for audio in input_lists:
        test_one(
            args=args,
            wav_root=os.path.join(args.input, audio),
            store_root=os.path.join(args.output, audio),
            rescale=args.rescale,
            soundstream=soundstream)


def _strip_module_prefix(state_dict, prefix='module.'):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed from keys.

    Checkpoint keys are expected to look like ``module.xxx.weight``,
    presumably from a (Distributed)DataParallel wrapper. Keys that do not
    start with ``prefix`` are kept unchanged rather than having their first
    characters chopped off.
    """
    return OrderedDict(
        (key[len(prefix):] if key.startswith(prefix) else key, value)
        for key, value in state_dict.items())
if __name__ == "__main__":
    # Script entry point: run the batch compress/decompress round-trip.
    test_batch()