File size: 1,662 Bytes
12bfd03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import argparse
import glob
import json
import os
from pathlib import Path

import soundfile as sf
from tqdm import tqdm

from academicodec.models.hificodec.vqvae_tester import VqvaeTester

parser = argparse.ArgumentParser(
    description="Reconstruct *.wav files with a pretrained HiFi-Codec model.")

# Paths
parser.add_argument('--outputdir', type=str, required=True,
                    help="Directory the reconstructed wav files are written to "
                         "(created if missing).")
parser.add_argument('--model_path', type=str, required=True,
                    help="Model file passed to VqvaeTester as model_path.")
parser.add_argument('--input_wavdir', type=str, required=True,
                    help="Directory whose top-level *.wav files are reconstructed.")
parser.add_argument('--config_path', type=str, required=True,
                    help="JSON model config; must contain 'sampling_rate'.")

# Generation
parser.add_argument('--num_gens', type=int, default=1024,
                    help="Maximum number of wav files to process.")

# Data
parser.add_argument('--sample_rate', type=int, default=24000,
                    help="Sample rate used for the model and the written files; "
                         "must equal the config's 'sampling_rate'.")

args = parser.parse_args()

# Merge the JSON model config into ``args``. The command-line values are
# applied last, so they win on any key collision with the config.
with open(args.config_path, 'r', encoding='utf-8') as f:
    argdict = json.load(f)

# The config key is 'sampling_rate'. Use .get so a missing key produces a clear
# CLI error instead of a bare KeyError. Use parser.error rather than assert so
# the check still runs under `python -O`.
config_sample_rate = argdict.get('sampling_rate')
if config_sample_rate != args.sample_rate:
    parser.error(
        f"Sampling rate not consistent, stated {args.sample_rate}, "
        f"but the model is trained on {config_sample_rate}")

argdict.update(vars(args))
args.__dict__ = argdict

if __name__ == '__main__':
    # Script entry point: load the codec, then encode/decode every selected wav
    # file and write the reconstruction to --outputdir under the id the model
    # returns.
    Path(args.outputdir).mkdir(parents=True, exist_ok=True)

    print("Init model and load weights")
    model = VqvaeTester(config_path=args.config_path,
                        model_path=args.model_path,
                        sample_rate=args.sample_rate)
    model.cuda()  # requires a CUDA-capable device
    # NOTE(review): presumably strips the weight-norm reparametrisation for
    # inference; confirm against the generator/encoder implementations.
    model.vqvae.generator.remove_weight_norm()
    model.vqvae.encoder.remove_weight_norm()
    model.eval()
    print("Model ready")

    # Escape the directory so glob metacharacters in it ('[', '*', '?') are
    # matched literally. Sort because glob order is arbitrary, which would make
    # the --num_gens subset differ between runs or filesystems.
    wav_pattern = os.path.join(glob.escape(args.input_wavdir), '*.wav')
    wav_paths = sorted(glob.glob(wav_pattern))[:args.num_gens]
    print(f"Globbed {len(wav_paths)} wav files.")

    for wav_path in tqdm(wav_paths):
        fid, wav = model(wav_path)
        # Drop singleton dimensions and move to host memory as a NumPy array,
        # which is what soundfile.write consumes.
        wav = wav.squeeze().cpu().numpy()
        sf.write(
            os.path.join(args.outputdir, f'{fid}.wav'), wav, args.sample_rate)