File size: 6,472 Bytes
12bfd03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Command-line for audio compression."""
import argparse
import os
import sys
import typing as tp
from collections import OrderedDict
from pathlib import Path

import librosa
import soundfile as sf
import torch
from academicodec.models.encodec.net3 import SoundStream


def save_audio(wav: torch.Tensor,
               path: tp.Union[Path, str],
               sample_rate: int,
               rescale: bool=False):
    """Write a waveform tensor to an audio file, guarding against clipping.

    When ``rescale`` is True the signal is scaled down (never up) so its peak
    stays within the limit; otherwise out-of-range samples are hard-clipped.
    """
    peak_limit = 0.99
    peak = wav.abs().max()
    if rescale:
        # Only attenuate when the peak exceeds the limit; never amplify.
        wav = wav * min(peak_limit / peak, 1)
    else:
        # Hard-clip everything outside [-peak_limit, peak_limit].
        wav = wav.clamp(-peak_limit, peak_limit)
    sf.write(path, wav.squeeze().cpu().numpy(), sample_rate)


def get_parser():
    """Build the argument parser for the encodec compression CLI.

    Returns:
        argparse.ArgumentParser: parser with input/output paths, model
        checkpoint path, sample rate, rescale flag, and SoundStream
        hyper-parameters (ratios, target bandwidths).
    """
    parser = argparse.ArgumentParser(
        'encodec',
        description='High fidelity neural audio codec. '
        'If input is a .ecdc, decompresses it. '
        'If input is .wav, compresses it. If output is also wav, '
        'do a compression/decompression cycle.')
    parser.add_argument(
        '--input',
        type=Path,
        help='Input file, whatever is supported by torchaudio on your system.')
    parser.add_argument(
        '--output',
        type=Path,
        nargs='?',
        help='Output file, otherwise inferred from input file.')
    parser.add_argument(
        '--resume_path', type=str, default='resume_path', help='resume_path')
    parser.add_argument(
        '--sr', type=int, default=16000, help='sample rate of model')
    parser.add_argument(
        '-r',
        '--rescale',
        action='store_true',
        help='Automatically rescale the output to avoid clipping.')
    parser.add_argument(
        '--ratios',
        type=int,
        nargs='+',
        # prod(ratios) = hop_size, e.g. 8*5*4*2 = 320 for the 16k_320d model
        default=[8, 5, 4, 2],
        help='ratios of SoundStream, should be set for different hop_size (32d, 320, 240d, ...)'
    )
    parser.add_argument(
        '--target_bandwidths',
        type=float,
        nargs='+',
        # default for 16k_320d
        default=[1, 1.5, 2, 4, 6, 12],
        help='target_bandwidths of net3.py')
    parser.add_argument(
        '--target_bw',
        type=float,
        # default for 16k_320d
        default=12,
        help='target_bw of net3.py')

    return parser


def fatal(*args):
    """Print *args* to stderr and terminate the process with exit status 1."""
    print(*args, file=sys.stderr)
    raise SystemExit(1)


# NOTE: this only warns about clipping; it does not actually clip the signal.
def check_clipping(wav, rescale):
    """Warn on stderr when *wav* peaks above the limit (skipped if rescaling)."""
    if rescale:
        return
    limit = 0.99
    peak = wav.abs().max()
    if peak > limit:
        print(
            f"Clipping!! max scale {peak}, limit is {limit}. "
            "To avoid clipping, use the `-r` option to rescale the output.",
            file=sys.stderr)


def test_one(args, wav_root, store_root, rescale, soundstream):
    """Run one compress/decompress cycle on a single wav file.

    Loads ``wav_root``, encodes/decodes it with ``soundstream`` at
    ``args.target_bw``, and writes the reconstruction to ``store_root``.
    """
    # librosa resamples to the model's sample rate on load (mono float32);
    # torchaudio.load would keep the file's native rate and need an explicit
    # Resample + channel selection instead.
    signal, _ = librosa.load(wav_root, sr=args.sr)
    signal = torch.tensor(signal).unsqueeze(0)

    # Add the batch axis -> [B, C, T] and move to GPU.
    signal = signal.unsqueeze(1).cuda()

    # compress, then reconstruct
    codes = soundstream.encode(signal, target_bw=args.target_bw)
    print('finish compressing')
    decoded = soundstream.decode(codes)
    decoded = decoded.detach().cpu().squeeze(0)
    check_clipping(decoded, rescale)
    save_audio(wav=decoded, path=store_root, sample_rate=args.sr, rescale=rescale)
    print('finish decompressing')


def remove_encodec_weight_norm(model):
    """Strip weight normalization from every conv layer of an EnCodec model.

    Required before inference on a checkpoint trained with weight norm, so
    the fused weights are used directly.
    """
    from academicodec.modules import SConv1d
    from academicodec.modules.seanet import SConvTranspose1d
    from academicodec.modules.seanet import SEANetResnetBlock
    from torch.nn.utils import remove_weight_norm

    def _strip(container, handle_transpose):
        # Walk one sequential container, removing weight norm per layer type.
        # `handle_transpose` mirrors the original asymmetry: only the decoder
        # contains (and handles) SConvTranspose1d layers.
        for layer in container._modules.values():
            if isinstance(layer, SEANetResnetBlock):
                remove_weight_norm(layer.shortcut.conv.conv)
                for sub in layer.block._modules.values():
                    if isinstance(sub, SConv1d):
                        remove_weight_norm(sub.conv.conv)
            elif handle_transpose and isinstance(layer, SConvTranspose1d):
                remove_weight_norm(layer.convtr.convtr)
            elif isinstance(layer, SConv1d):
                remove_weight_norm(layer.conv.conv)

    _strip(model.encoder.model, handle_transpose=False)
    _strip(model.decoder.model, handle_transpose=True)


def test_batch():
    """Batch-compress every audio file under ``--input`` into ``--output``.

    Loads a SoundStream checkpoint, strips DataParallel key prefixes and
    weight norm, moves the model to GPU, then runs a compress/decompress
    cycle per input file via :func:`test_one`.
    """
    args = get_parser().parse_args()
    print("args.target_bandwidths:", args.target_bandwidths)
    if not args.input.exists():
        fatal(f"Input file {args.input} does not exist.")
    input_lists = os.listdir(args.input)
    input_lists.sort()
    soundstream = SoundStream(
        n_filters=32,
        D=512,
        ratios=args.ratios,
        sample_rate=args.sr,
        target_bandwidths=args.target_bandwidths)
    # Load onto CPU first so loading works regardless of the device the
    # checkpoint was saved from; the model is moved to GPU below.
    parameter_dict = torch.load(args.resume_path, map_location='cpu')
    new_state_dict = OrderedDict()
    for k, v in parameter_dict.items():
        # Strip the `module.` prefix added by DataParallel training, if
        # present; checkpoints saved without DataParallel pass through as-is.
        name = k[7:] if k.startswith('module.') else k
        new_state_dict[name] = v
    soundstream.load_state_dict(new_state_dict)  # load model
    remove_encodec_weight_norm(soundstream)
    soundstream.cuda()
    soundstream.eval()
    os.makedirs(args.output, exist_ok=True)
    for audio in input_lists:
        test_one(
            args=args,
            wav_root=os.path.join(args.input, audio),
            store_root=os.path.join(args.output, audio),
            rescale=args.rescale,
            soundstream=soundstream)


# CLI entry point: run the batch compress/decompress pipeline.
if __name__ == '__main__':
    test_batch()