# infer-vst/back/models/comparison.py
# Yann
# push backend
# 86694c3
import json
import os
import pickle
import re
import h5py
import numpy as np
from scipy.io import wavfile
from scipy.io.wavfile import write as write_wav
from tensorflow import keras
from generators.generator import InverSynthGenerator, SoundGenerator, VSTGenerator
from generators.parameters import ParameterSet
"""
This module generates comparisons: it takes an original sound and its
parameters, then renders a new audio file from the parameters predicted
by the model.
"""
def compare(
    model: keras.Model,
    generator: SoundGenerator,
    parameters: ParameterSet,
    orig_file: str,
    output_dir: str,
    orig_params,
    length: float,
    sample_rate: int,
    extra: dict = None,
):
    """Write a copy, a re-render and a model reconstruction of one sound.

    Three files are produced in ``output_dir``, all named after the stem of
    ``orig_file``:

    * ``<stem>_copy.wav``        - the samples read from ``orig_file``;
    * ``<stem>_duplicate.wav``   - audio rendered from ``orig_params``;
    * ``<stem>_reconstruct.wav`` - audio rendered from the parameters
      predicted by ``model`` for the original samples.

    Args:
        model: Model whose ``predict`` output is an encoded parameter vector.
        generator: Renders a settings object to a WAV file via ``generate``.
        parameters: Decodes encoded vectors via ``encoding_to_settings``.
        orig_file: Path of the original WAV file.
        output_dir: Directory that receives the three output files.
        orig_params: Encoded ground-truth parameters for ``orig_file``.
        length: Passed through to ``generator.generate``.
        sample_rate: Passed through to ``generator.generate``.
        extra: Extra options forwarded to ``generator.generate``; ``None``
            (the default) means an empty dict.
    """
    # Use a fresh dict per call; a literal ``{}`` default is created once and
    # shared by every call that omits the argument.
    if extra is None:
        extra = {}

    # Strip the directory and a trailing ".wav" only. Removing every ".wav"
    # occurrence would mangle names such as "my.wavetable.wav".
    base_filename = os.path.basename(orig_file)
    if base_filename.endswith(".wav"):
        base_filename = base_filename[: -len(".wav")]

    copy_file: str = os.path.join(output_dir, f"{base_filename}_copy.wav")
    regen_file: str = os.path.join(output_dir, f"{base_filename}_duplicate.wav")
    reconstruct_file: str = os.path.join(
        output_dir, f"{base_filename}_reconstruct.wav"
    )
    print(f"Creating copy as {copy_file}")

    # Load the wave file
    fs, data = wavfile.read(orig_file)

    # Write the copy at the rate stored in the source file so that it is a
    # faithful duplicate even if ``sample_rate`` differs from it.
    write_wav(copy_file, fs, data)

    # Decode original params, and regenerate output (to make sure it's correct)
    orig = parameters.encoding_to_settings(orig_params)
    generator.generate(orig, regen_file, length, sample_rate, extra)

    # Run the wave data through the model as a single-item batch with a
    # trailing axis added (presumably the channel axis - confirm against the
    # model's input layer).
    Xd = np.expand_dims(np.vstack([data]), axis=2)

    # Get encoded parameters out of the model (first item of the batch)
    result = model.predict(Xd)[0]

    # Decode prediction, and reconstruct output
    predicted = parameters.encoding_to_settings(result)
    generator.generate(predicted, reconstruct_file, length, sample_rate, extra)
def run_comparison(
    model: keras.Model,
    generator: SoundGenerator,
    run_name: str,
    indices=None,
    num_samples=10,
    data_dir="./test_datasets",
    output_dir="./comparison",
    length=1.0,
    sample_rate=16384,
    shuffle=True,
    extra=None,
):
    """Run ``compare`` on a selection of sounds from a stored dataset.

    The dataset is located from ``run_name``: samples are read from
    ``{data_dir}/{run_name}_data.hdf5`` (datasets ``"files"`` and
    ``"labels"``) and the parameter set is unpickled from
    ``{data_dir}/{run_name}_params.pckl``. Output files are written to
    ``{output_dir}/{run_name}/``, which is created if missing.

    Args:
        model: Model passed through to ``compare``.
        generator: Sound generator passed through to ``compare``.
        run_name: Name of the dataset run; selects the input files and the
            output sub-directory.
        indices: Dataset indices to compare. When ``None`` or empty, the
            first ``num_samples`` indices are used (after shuffling if
            ``shuffle`` is true).
        num_samples: Number of indices to pick when ``indices`` is not given.
        data_dir: Directory containing the dataset and parameter files.
        output_dir: Parent directory for the comparison output.
        length: Passed through to ``compare``.
        sample_rate: Passed through to ``compare``.
        shuffle: Whether to shuffle before picking ``num_samples`` indices.
        extra: Extra generator options passed through to ``compare``;
            ``None`` (the default) means an empty dict.
    """
    # Use a fresh dict per call instead of one shared mutable default.
    if extra is None:
        extra = {}

    # Figure out data file and params file from run name
    data_file = f"{data_dir}/{run_name}_data.hdf5"
    parameters_file = f"{data_dir}/{run_name}_params.pckl"
    print(f"Reading parameters from {parameters_file}")

    # NOTE(security): unpickling executes arbitrary code from the file; only
    # point this at parameter files you generated yourself.
    with open(parameters_file, "rb") as params_fh:
        parameters = pickle.load(params_fh)

    output_dir = f"{output_dir}/{run_name}/"
    os.makedirs(output_dir, exist_ok=True)

    # The context manager closes the HDF5 file even if a comparison fails.
    with h5py.File(data_file, "r") as database:
        # ``not indices`` raises for multi-element NumPy arrays, so test for
        # "not given" explicitly while still treating an empty sequence as
        # "pick for me".
        no_indices = indices is None or (
            hasattr(indices, "__len__") and len(indices) == 0
        )
        if no_indices:
            ids = np.arange(len(database["files"]))
            if shuffle:
                np.random.shuffle(ids)
            indices = ids[0:num_samples]

        for i in indices:
            print("Looking at index: {}".format(i))
            filename = database["files"][i]
            # h5py 3 returns stored strings as bytes; ``compare`` needs str.
            if isinstance(filename, bytes):
                filename = filename.decode("utf-8")
            labels = database["labels"][i]
            compare(
                model=model,
                generator=generator,
                parameters=parameters,
                orig_file=filename,
                output_dir=output_dir,
                orig_params=labels,
                length=length,
                sample_rate=sample_rate,
                extra=extra,
            )
# Generate
if __name__ == "__main__":
    # Render settings used by the VST run below.
    note_length = 0.8
    sample_rate = 16384

    # Switches selecting which comparisons are executed.
    lokomotiv = True
    fm = True

    if lokomotiv:
        # Kept as a star import at this exact position: it may rebind
        # module-level names, so its order relative to the assignments
        # around it is preserved.
        from generators.vst_generator import *

        run_name = "lokomotiv_full"
        model_file = "output/lokomotiv_full_e2e_best.h5"
        plugin = "/Library/Audio/Plug-Ins/VST/Lokomotiv.vst"
        config_file = "plugin_config/lokomotiv.json"

        generator = VSTGenerator(vst=plugin, sample_rate=sample_rate)
        with open(config_file, "r") as config_handle:
            config = json.load(config_handle)
        model = keras.models.load_model(model_file)

        # Options forwarded to the generator for every rendered file.
        vst_extra = {"note_length": note_length, "config": config}
        run_comparison(
            model=model,
            generator=generator,
            run_name=run_name,
            num_samples=100,
            extra=vst_extra,
        )

    if fm:
        # Same ordering caveat as the star import above.
        from generators.fm_generator import *

        run_name = "inversynth_full"
        model_file = "output/inversynth_full_e2e_best.h5"

        generator = InverSynthGenerator()
        model = keras.models.load_model(model_file)
        run_comparison(
            model=model,
            generator=generator,
            run_name=run_name,
            num_samples=100,
        )