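"""Compute a 256-d speaker embedding for a single wav file with a pretrained
LSTM speaker encoder.

Example invocation (module path and file names are illustrative; the script
uses relative imports, so run it as a module from the parent package):

    python -m speaker_encoder.compute_embedding model.pth config.json \
        -s input.wav -t embedding.npy
"""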
import argparse
import json
import re
from argparse import RawTextHelpFormatter

import fsspec
import numpy as np
import torch

from .models.lstm import LSTMSpeakerEncoder
from .config import SpeakerEncoderConfig
from .utils.audio import AudioProcessor
def read_json(json_path):
    """Read a JSON config file, falling back to comment-tolerant parsing."""
    config_dict = {}
    try:
        with fsspec.open(json_path, "r", encoding="utf-8") as f:
            data = json.load(f)
    except json.decoder.JSONDecodeError:
        # Backwards compatibility: legacy configs may contain // comments.
        data = read_json_with_comments(json_path)
    config_dict.update(data)
    return config_dict
def read_json_with_comments(json_path):
    """Parse JSON that contains // line comments (legacy config format)."""
    with fsspec.open(json_path, "r", encoding="utf-8") as f:
        input_str = f.read()
    # Join escaped line continuations, then drop // comments. Note this naive
    # strip also mangles "//" inside string values (e.g. URLs).
    input_str = re.sub(r"\\\n", "", input_str)
    input_str = re.sub(r"//.*\n", "\n", input_str)
    data = json.loads(input_str)
    return data
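
# A minimal sketch of the comment-bearing config this fallback accepts
# (keys are illustrative, not from a real SpeakerEncoderConfig):
#
#   {
#       "run_name": "speaker_encoder",  // label for the run
#       "audio": {"sample_rate": 16000}
#   }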
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="""Compute embedding vectors for each wav file in a dataset.""",
        formatter_class=RawTextHelpFormatter,
    )
    parser.add_argument("model_path", type=str, help="Path to model checkpoint file.")
    parser.add_argument(
        "config_path",
        type=str,
        help="Path to model config file.",
    )
    parser.add_argument("-s", "--source", help="Input wav file.", dest="source")
    parser.add_argument(
        "-t", "--target", help="Output 256-d speaker embedding file.", dest="target"
    )
    # `type=bool` is a footgun: bool("False") is True, so any value passed on
    # the command line would enable the flag. BooleanOptionalAction (Python 3.9+)
    # keeps the True default and adds --no-use_cuda / --no-eval to disable.
    parser.add_argument(
        "--use_cuda",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Run inference on GPU.",
    )
    parser.add_argument(
        "--eval",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Load the checkpoint in eval mode.",
    )
    args = parser.parse_args()
    source_file = args.source
    target_file = args.target

    # Load the config and build the model definition from it.
    config_dict = read_json(args.config_path)
    config = SpeakerEncoderConfig()
    config.from_dict(config_dict)
    speaker_encoder = LSTMSpeakerEncoder(
        config.model_params["input_dim"],
        config.model_params["proj_dim"],
        config.model_params["lstm_dim"],
        config.model_params["num_lstm_layers"],
    )
    # Honor the --eval flag instead of hard-coding eval=True.
    speaker_encoder.load_checkpoint(args.model_path, eval=args.eval, use_cuda=args.use_cuda)
    # Audio preprocessing: normalize the input level and trim silences.
    speaker_encoder_ap = AudioProcessor(**config.audio)
    speaker_encoder_ap.do_sound_norm = True
    speaker_encoder_ap.do_trim_silence = True
    # Compute the speaker embedding: waveform -> mel spectrogram -> LSTM encoder.
    waveform = speaker_encoder_ap.load_wav(source_file, sr=speaker_encoder_ap.sample_rate)
    spec = speaker_encoder_ap.melspectrogram(waveform)
    spec = torch.from_numpy(spec.T)
    if args.use_cuda:
        spec = spec.cuda()
    spec = spec.unsqueeze(0)  # add a batch dimension: (1, frames, n_mels)
    embed = speaker_encoder.compute_embedding(spec).detach().cpu().numpy()
    embed = embed.squeeze()
    # np.save appends ".npy" to the file name if it lacks an extension.
    np.save(target_file, embed, allow_pickle=False)
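    # Downstream, speaker embeddings are typically compared with cosine
    # similarity. A minimal sketch ("other.npy" is a hypothetical second
    # embedding saved by another run of this script):
    #
    #   other = np.load("other.npy")
    #   cos_sim = np.dot(embed, other) / (np.linalg.norm(embed) * np.linalg.norm(other))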
    # Re-save only the model weights as a smaller standalone checkpoint.
    if hasattr(speaker_encoder, "module"):
        # unwrap DataParallel / DistributedDataParallel
        state_dict = speaker_encoder.module.state_dict()
    else:
        state_dict = speaker_encoder.state_dict()
    torch.save({"model": state_dict}, "model_small.pth")
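    # To reload the trimmed checkpoint later, a minimal sketch (assumes the same
    # model_params as above and that LSTMSpeakerEncoder is a torch nn.Module):
    #
    #   encoder = LSTMSpeakerEncoder(input_dim, proj_dim, lstm_dim, num_lstm_layers)
    #   encoder.load_state_dict(torch.load("model_small.pth")["model"])
    #   encoder.eval()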