Spaces:
Runtime error
Runtime error
File size: 1,662 Bytes
12bfd03 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import argparse
import glob
import json
import os
from pathlib import Path
import soundfile as sf
from tqdm import tqdm
from academicodec.models.hificodec.vqvae_tester import VqvaeTester
# Command-line interface for the inference script.
parser = argparse.ArgumentParser()
# Path
parser.add_argument('--outputdir', type=str, required=True)
parser.add_argument('--model_path', type=str, required=True)
parser.add_argument('--input_wavdir', type=str, required=True)
parser.add_argument('--config_path', type=str, required=True)
# Upper bound on how many *.wav files from --input_wavdir are processed.
parser.add_argument('--num_gens', type=int, default=1024)
# Data
parser.add_argument('--sample_rate', type=int, default=24000)
args = parser.parse_args()

# Load the model's JSON config and make sure it agrees with --sample_rate.
with open(args.config_path, 'r', encoding='utf-8') as f:
    argdict = json.load(f)
if argdict['sampling_rate'] != args.sample_rate:
    # Explicit check instead of `assert` so it still runs under `python -O`.
    # The message reads the same 'sampling_rate' key that is compared above;
    # reading a different key here used to raise KeyError and hide the error.
    parser.error(
        f"Sampling rate not consistent, stated {args.sample_rate}, "
        f"but the model is trained on {argdict['sampling_rate']}")
# Overlay the CLI values on the config (command-line wins on key clashes),
# then expose the merged mapping as attributes of `args`.
argdict.update(args.__dict__)
args.__dict__ = argdict
if __name__ == '__main__':
    Path(args.outputdir).mkdir(parents=True, exist_ok=True)
    print("Init model and load weights")
    model = VqvaeTester(config_path=args.config_path,
                        model_path=args.model_path,
                        sample_rate=args.sample_rate)
    model.cuda()
    # NOTE(review): presumably removes weight normalization for inference;
    # confirm against the generator/encoder implementations.
    model.vqvae.generator.remove_weight_norm()
    model.vqvae.encoder.remove_weight_norm()
    model.eval()
    print("Model ready")
    # glob order is arbitrary and filesystem-dependent. Sort so the subset
    # kept by --num_gens is reproducible. Escape the directory so glob
    # metacharacters in its name are matched literally.
    pattern = os.path.join(glob.escape(args.input_wavdir), '*.wav')
    wav_paths = sorted(glob.glob(pattern))[:args.num_gens]
    print(f"Globbed {len(wav_paths)} wav files.")
    for wav_path in tqdm(wav_paths):
        # `model(...)` yields an id used for the output filename and a tensor
        # that is squeezed and moved to host memory before writing.
        fid, wav = model(wav_path)
        wav = wav.squeeze().cpu().numpy()
        sf.write(
            os.path.join(args.outputdir, f'{fid}.wav'), wav, args.sample_rate)
|