Spaces:
Sleeping
Sleeping
File size: 4,217 Bytes
86694c3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import json
import os
import pickle
import re
import h5py
import numpy as np
from scipy.io import wavfile
from scipy.io.wavfile import write as write_wav
from tensorflow import keras
from generators.generator import InverSynthGenerator, SoundGenerator, VSTGenerator
from generators.parameters import ParameterSet
"""
This module generates comparisons - takes the original sound + params,
then generates a file with the predicted parameters
"""
def compare(
    model: keras.Model,
    generator: SoundGenerator,
    parameters: ParameterSet,
    orig_file: str,
    output_dir: str,
    orig_params,
    length: float,
    sample_rate: int,
    extra: dict = None,
):
    """Write a copy, a regeneration and a model reconstruction of one sound.

    Three files are written to ``output_dir``, all named after the base name
    of ``orig_file``:

    * ``<name>_copy.wav`` - the samples read from ``orig_file``.
    * ``<name>_duplicate.wav`` - audio generated from ``orig_params``.
    * ``<name>_reconstruct.wav`` - audio generated from the parameters the
      model predicts for the original samples.

    Args:
        model: Model whose ``predict`` output is decoded with
            ``parameters.encoding_to_settings``.
        generator: Generator used to render the two regenerated files.
        parameters: Parameter set used to decode encoded parameters.
        orig_file: Path of the original wave file. ``bytes`` and path-like
            values are accepted and converted to ``str``.
        output_dir: Directory the three files are written to.
        orig_params: Encoded parameters of the original sound.
        length: Length forwarded to ``generator.generate``.
        sample_rate: Sample rate forwarded to ``generator.generate``.
        extra: Extra settings forwarded to ``generator.generate``. ``None``
            (the default) is replaced by a fresh empty dict.
    """
    # A fresh dict per call: a shared default could be mutated by the
    # generator and leak state between calls.
    if extra is None:
        extra = {}

    # Filenames read from HDF5 may arrive as bytes; normalise to str.
    orig_file = os.fsdecode(orig_file)

    # Strip the directory and only a trailing ".wav" (not every occurrence).
    base_filename = os.path.basename(orig_file)
    if base_filename.endswith(".wav"):
        base_filename = base_filename[: -len(".wav")]

    copy_file: str = os.path.join(output_dir, f"{base_filename}_copy.wav")
    regen_file: str = os.path.join(output_dir, f"{base_filename}_duplicate.wav")
    reconstruct_file: str = os.path.join(
        output_dir, f"{base_filename}_reconstruct.wav"
    )
    print(f"Creating copy as {copy_file}")

    # Load the wave file
    fs, data = wavfile.read(orig_file)

    # Copy the original samples using the rate stored in the source file, so
    # the copy is faithful even if it differs from ``sample_rate``.
    write_wav(copy_file, fs, data)

    # Decode original params, and regenerate output (make sure it's correct)
    orig = parameters.encoding_to_settings(orig_params)
    generator.generate(orig, regen_file, length, sample_rate, extra)

    # Stack the samples into a batch and append a trailing axis for the model.
    # NOTE(review): presumably ``data`` is mono (1-D), giving shape
    # (1, samples, 1) - confirm against the model's input.
    Xd = np.expand_dims(np.vstack([data]), axis=2)

    # Get encoded parameters out of model (first, and only, batch entry)
    result = model.predict(Xd)[0]

    # Decode prediction, and reconstruct output
    predicted = parameters.encoding_to_settings(result)
    generator.generate(predicted, reconstruct_file, length, sample_rate, extra)
def run_comparison(
    model: keras.Model,
    generator: SoundGenerator,
    run_name: str,
    indices=None,
    num_samples=10,
    data_dir="./test_datasets",
    output_dir="./comparison",
    length=1.0,
    sample_rate=16384,
    shuffle=True,
    extra=None,
):
    """Run ``compare`` on a selection of entries from a stored dataset.

    The dataset is read from ``{data_dir}/{run_name}_data.hdf5`` (datasets
    ``"files"`` and ``"labels"``) and the parameter set from
    ``{data_dir}/{run_name}_params.pckl``. Output files are written below
    ``{output_dir}/{run_name}/``, which is created if missing.

    Args:
        model: Model passed through to ``compare``.
        generator: Sound generator passed through to ``compare``.
        run_name: Name used to locate the data/parameter files and to name
            the output sub-directory.
        indices: Dataset indices to compare. When ``None`` or empty, the
            first ``num_samples`` indices are used (after shuffling when
            ``shuffle`` is true).
        num_samples: Number of entries selected when ``indices`` is not given.
        data_dir: Directory containing the dataset and parameter files.
        output_dir: Root directory for the comparison output.
        length: Length passed through to ``compare``.
        sample_rate: Sample rate passed through to ``compare``.
        shuffle: Whether to shuffle the indices before selecting.
        extra: Extra settings passed through to ``compare``. ``None`` (the
            default) is replaced by a fresh empty dict.
    """
    # A fresh dict per call avoids sharing one mutable default between calls.
    if extra is None:
        extra = {}

    # Figure out data file and params file from run name
    data_file = f"{data_dir}/{run_name}_data.hdf5"
    parameters_file = f"{data_dir}/{run_name}_params.pckl"
    print(f"Reading parameters from {parameters_file}")

    # NOTE: unpickling can execute arbitrary code; only load parameter files
    # that come from a trusted source.
    with open(parameters_file, "rb") as params_fh:
        parameters = pickle.load(params_fh)

    output_dir = f"{output_dir}/{run_name}/"
    os.makedirs(output_dir, exist_ok=True)

    # Keep the HDF5 file open only while entries are being read from it.
    with h5py.File(data_file, "r") as database:
        # ``not indices`` is ambiguous for NumPy arrays, so test explicitly.
        no_indices = indices is None or (
            hasattr(indices, "__len__") and len(indices) == 0
        )
        if no_indices:
            ids = np.arange(len(database["files"]))
            if shuffle:
                np.random.shuffle(ids)
            indices = ids[:num_samples]

        for i in indices:
            print(f"Looking at index: {i}")
            filename = database["files"][i]
            # h5py may return stored strings as bytes; compare() needs str.
            if isinstance(filename, bytes):
                filename = filename.decode("utf-8")
            labels = database["labels"][i]
            compare(
                model=model,
                generator=generator,
                parameters=parameters,
                orig_file=filename,
                output_dir=output_dir,
                orig_params=labels,
                length=length,
                sample_rate=sample_rate,
                extra=extra,
            )
# Generate
if __name__ == "__main__":
    # Settings shared by the runs below.
    note_length = 0.8
    sample_rate = 16384

    # Switches selecting which comparisons are executed.
    lokomotiv = True
    fm = True

    if lokomotiv:
        # Star import kept at module level and in place: it may rebind names.
        from generators.vst_generator import *

        run_name = "lokomotiv_full"
        model_file = "output/lokomotiv_full_e2e_best.h5"
        plugin = "/Library/Audio/Plug-Ins/VST/Lokomotiv.vst"
        config_file = "plugin_config/lokomotiv.json"

        # Same order as before: plugin first, then config, then model.
        generator = VSTGenerator(vst=plugin, sample_rate=sample_rate)
        with open(config_file, "r") as f:
            config = json.load(f)
        model = keras.models.load_model(model_file)

        vst_extra = {"note_length": note_length, "config": config}
        run_comparison(
            model=model,
            generator=generator,
            run_name=run_name,
            num_samples=100,
            extra=vst_extra,
        )

    if fm:
        # Star import kept at module level and in place: it may rebind names.
        from generators.fm_generator import *

        run_name = "inversynth_full"
        model_file = "output/inversynth_full_e2e_best.h5"

        generator = InverSynthGenerator()
        model = keras.models.load_model(model_file)
        run_comparison(
            model=model,
            generator=generator,
            run_name=run_name,
            num_samples=100,
        )
|