import argparse import glob import json import os from pathlib import Path import soundfile as sf from tqdm import tqdm from academicodec.models.hificodec.vqvae_tester import VqvaeTester parser = argparse.ArgumentParser() #Path parser.add_argument('--outputdir', type=str, required=True) parser.add_argument('--model_path', type=str, required=True) parser.add_argument('--input_wavdir', type=str, required=True) parser.add_argument('--config_path', type=str, required=True) parser.add_argument('--num_gens', type=int, default=1024) #Data parser.add_argument('--sample_rate', type=int, default=24000) args = parser.parse_args() with open(args.config_path, 'r') as f: argdict = json.load(f) assert argdict['sampling_rate'] == args.sample_rate, \ f"Sampling rate not consistent, stated {args.sample_rate}, but the model is trained on {argdict['sample_rate']}" argdict.update(args.__dict__) args.__dict__ = argdict if __name__ == '__main__': Path(args.outputdir).mkdir(parents=True, exist_ok=True) print("Init model and load weights") model = VqvaeTester(config_path=args.config_path, model_path=args.model_path,sample_rate=args.sample_rate) model.cuda() model.vqvae.generator.remove_weight_norm() model.vqvae.encoder.remove_weight_norm() model.eval() print("Model ready") wav_paths = glob.glob(f"{args.input_wavdir}/*.wav")[:args.num_gens] print(f"Globbed {len(wav_paths)} wav files.") for wav_path in wav_paths: fid, wav = model(wav_path) wav = wav.squeeze().cpu().numpy() sf.write( os.path.join(args.outputdir, f'{fid}.wav'), wav, args.sample_rate)