|
import datetime |
|
import json |
|
import os |
|
import pickle as pickle_tts |
|
import shutil |
|
from typing import Any, Callable, Dict, Union |
|
|
|
import fsspec |
|
import torch |
|
from coqpit import Coqpit |
|
|
|
from TTS.utils.generic_utils import get_user_data_dir |
|
|
|
|
|
class RenamingUnpickler(pickle_tts.Unpickler):
    """Unpickler that rewrites the legacy ``mozilla_voice_tts`` module path to ``TTS``.

    Every global looked up while unpickling has ``mozilla_voice_tts`` replaced
    by ``TTS`` in its module path before the normal lookup runs, so objects
    pickled under the old package name resolve against the renamed package.
    """

    def find_class(self, module, name):
        renamed_module = module.replace("mozilla_voice_tts", "TTS")
        return super().find_class(renamed_module, name)
|
|
|
|
|
class AttrDict(dict):
    """Dictionary whose keys can also be read and written as attributes.

    ``d.key`` and ``d["key"]`` refer to the same entry, in both directions.
    """

    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        # Use the mapping itself as the instance namespace, so attribute
        # access and item access share one storage.
        self.__dict__ = self
|
|
|
|
|
def copy_model_files(config: Coqpit, out_path, new_fields=None):
    """Write ``config.json`` (with optional overrides) and the audio stats file into ``out_path``.

    Args:
        config (Coqpit): Coqpit config defining the training run. When
            ``new_fields`` is given, ``config`` is updated in place before
            being written.
        out_path (str): Destination folder (any fsspec-supported location).
        new_fields (dict): Fields to add to or overwrite in the config before
            it is written to ``config.json``.

    The file at ``config.audio.stats_path`` (if set) is copied to
    ``out_path/scale_stats.npy``, but only when that target does not exist yet.
    """
    if new_fields:
        config.update(new_fields, allow_new=True)

    config_target = os.path.join(out_path, "config.json")
    with fsspec.open(config_target, "w", encoding="utf8") as config_file:
        json.dump(config.to_dict(), config_file, indent=4)

    if config.audio.stats_path is None:
        return

    stats_target = os.path.join(out_path, "scale_stats.npy")
    target_fs = fsspec.get_mapper(stats_target).fs
    if target_fs.exists(stats_target):
        return

    with fsspec.open(config.audio.stats_path, "rb") as src, fsspec.open(stats_target, "wb") as dst:
        shutil.copyfileobj(src, dst)
|
|
|
|
|
def load_fsspec(
    path: str,
    map_location: Union[str, Callable, torch.device, Dict[Union[str, torch.device], Union[str, torch.device]]] = None,
    cache: bool = True,
    **kwargs,
) -> Any:
    """Load an object with ``torch.load`` from any fsspec location (e.g. ``s3://``, ``gs://``).

    Args:
        path: Any path or URL supported by fsspec.
        map_location: Forwarded to ``torch.load`` (a torch.device, str,
            callable or device mapping).
        cache: If True and ``path`` is not an existing local file or folder,
            open it through fsspec's ``filecache`` so the remote file is kept
            under ``get_user_data_dir("tts_cache")`` for later calls.
            Defaults to True.
        **kwargs: Extra keyword arguments forwarded to ``torch.load``.

    Returns:
        The object stored at ``path``.
    """
    is_local = os.path.isdir(path) or os.path.isfile(path)
    if cache and not is_local:
        opened = fsspec.open(
            f"filecache::{path}",
            filecache={"cache_storage": str(get_user_data_dir("tts_cache"))},
            mode="rb",
        )
    else:
        opened = fsspec.open(path, mode="rb")

    with opened as stream:
        return torch.load(stream, map_location=map_location, **kwargs)
|
|
|
|
|
def load_checkpoint(model, checkpoint_path, use_cuda=False, eval=False, cache=False):
    """Load the weights stored in ``checkpoint_path`` into ``model``.

    The checkpoint is first loaded onto the CPU. If unpickling fails with
    ``ModuleNotFoundError``, loading is retried with ``RenamingUnpickler``,
    which maps the legacy ``mozilla_voice_tts`` module path to ``TTS``.

    Args:
        model: Module whose ``load_state_dict`` receives ``state["model"]``.
        checkpoint_path: Any path or URL supported by fsspec.
        use_cuda: If True, move the model to the GPU with ``model.cuda()``.
        eval: If True, switch the model to evaluation mode.
        cache: Forwarded to ``load_fsspec`` to cache remote files locally.

    Returns:
        tuple: ``(model, state)``, where ``state`` is the full loaded checkpoint.
    """
    try:
        state = load_fsspec(checkpoint_path, map_location=torch.device("cpu"), cache=cache)
    except ModuleNotFoundError:
        # ``pickle_tts`` is the stdlib ``pickle`` module. Patch its Unpickler
        # only for this retry and always restore it. The previous code left
        # the patch in place for every other pickle user in the process.
        # The patch is still process-global while the retry runs.
        original_unpickler = pickle_tts.Unpickler
        pickle_tts.Unpickler = RenamingUnpickler
        try:
            state = load_fsspec(
                checkpoint_path, map_location=torch.device("cpu"), pickle_module=pickle_tts, cache=cache
            )
        finally:
            pickle_tts.Unpickler = original_unpickler
    model.load_state_dict(state["model"])
    if use_cuda:
        model.cuda()
    if eval:
        model.eval()
    return model, state
|
|
|
|
|
def save_fsspec(state: Any, path: str, **kwargs):
    """Save ``state`` with ``torch.save`` to any fsspec location (e.g. ``s3://``, ``gs://``).

    Args:
        state: Object to serialize.
        path: Any path or URL supported by fsspec.
        **kwargs: Extra keyword arguments forwarded to ``torch.save``.
    """
    destination = fsspec.open(path, "wb")
    with destination as handle:
        torch.save(state, handle, **kwargs)
|
|
|
|
|
def save_model(config, model, optimizer, scaler, current_step, epoch, output_path, **kwargs):
    """Serialize model/optimizer/scaler state and run metadata to ``output_path``.

    Args:
        config: Coqpit config (converted with ``to_dict()``) or an already plain object.
        model: Model to save. If it has a ``.module`` attribute (e.g. a
            ``torch.nn.DataParallel`` wrapper), the wrapped module's state is saved.
        optimizer: A single optimizer, a list of optimizers, a
            ``CapacitronOptimizer`` (its primary and secondary states are saved
            as a list), or None.
        scaler: A single grad scaler, a list of scalers, or None.
        current_step: Global training step, stored under ``"step"``.
        epoch: Current epoch, stored under ``"epoch"``.
        output_path: Destination path or URL, written through ``save_fsspec``.
        **kwargs: Extra entries merged into the saved dict; they override the
            standard keys on collision.
    """
    unwrapped = model.module if hasattr(model, "module") else model
    model_state = unwrapped.state_dict()

    if isinstance(optimizer, list):
        optimizer_state = [opt.state_dict() for opt in optimizer]
    elif optimizer.__class__.__name__ == "CapacitronOptimizer":
        optimizer_state = [optimizer.primary_optimizer.state_dict(), optimizer.secondary_optimizer.state_dict()]
    elif optimizer is None:
        optimizer_state = None
    else:
        optimizer_state = optimizer.state_dict()

    if isinstance(scaler, list):
        scaler_state = [one_scaler.state_dict() for one_scaler in scaler]
    elif scaler is None:
        scaler_state = None
    else:
        scaler_state = scaler.state_dict()

    config_payload = config.to_dict() if isinstance(config, Coqpit) else config

    checkpoint = {
        "config": config_payload,
        "model": model_state,
        "optimizer": optimizer_state,
        "scaler": scaler_state,
        "step": current_step,
        "epoch": epoch,
        "date": datetime.date.today().strftime("%B %d, %Y"),
        **kwargs,
    }
    save_fsspec(checkpoint, output_path)
|
|
|
|
|
def save_checkpoint(
    config,
    model,
    optimizer,
    scaler,
    current_step,
    epoch,
    output_folder,
    **kwargs,
):
    """Save a periodic checkpoint as ``<output_folder>/checkpoint_<current_step>.pth``.

    The target path is printed, then all arguments (plus ``kwargs``) are
    forwarded to ``save_model``.
    """
    checkpoint_path = os.path.join(output_folder, f"checkpoint_{current_step}.pth")
    print(f"\n > CHECKPOINT : {checkpoint_path}")
    save_model(config, model, optimizer, scaler, current_step, epoch, checkpoint_path, **kwargs)
|
|
|
|
|
def save_best_model(
    current_loss,
    best_loss,
    config,
    model,
    optimizer,
    scaler,
    current_step,
    epoch,
    out_path,
    keep_all_best=False,
    keep_after=10000,
    **kwargs,
):
    """Save ``best_model_<current_step>.pth`` when ``current_loss`` improves on ``best_loss``.

    When the loss improves:

    * the checkpoint is written with ``model_loss=current_loss``;
    * older ``best_model_*.pth`` files in ``out_path`` are deleted, unless
      ``keep_all_best`` is True and ``current_step >= keep_after``;
    * the new checkpoint is copied to ``best_model.pth``.

    Args:
        current_loss: Loss of the current model. It counts as better only
            if ``current_loss < best_loss``, so NaN never does.
        best_loss: Best loss seen so far.
        config, model, optimizer, scaler, current_step, epoch: Forwarded to ``save_model``.
        out_path: Output folder (any fsspec-supported location).
        keep_all_best: Keep previous best checkpoints once ``keep_after`` is reached.
        keep_after: Step from which ``keep_all_best`` takes effect.
        **kwargs: Extra entries forwarded to ``save_model``.

    Returns:
        The updated best loss: ``current_loss`` if it improved, else ``best_loss``.
    """
    if current_loss < best_loss:
        best_model_name = f"best_model_{current_step}.pth"
        checkpoint_path = os.path.join(out_path, best_model_name)
        print(" > BEST MODEL : {}".format(checkpoint_path))
        save_model(
            config,
            model,
            optimizer,
            scaler,
            current_step,
            epoch,
            checkpoint_path,
            model_loss=current_loss,
            **kwargs,
        )
        fs = fsspec.get_mapper(out_path).fs

        shortcut_name = "best_model.pth"
        shortcut_path = os.path.join(out_path, shortcut_name)

        if not keep_all_best or (current_step < keep_after):
            # The glob also matches the ``best_model.pth`` shortcut. Skip it:
            # it is overwritten just below, and deleting it first would lose
            # the previous best shortcut if the copy failed.
            for model_name in fs.glob(os.path.join(out_path, "best_model*.pth")):
                if os.path.basename(model_name) not in (best_model_name, shortcut_name):
                    fs.rm(model_name)

        fs.copy(checkpoint_path, shortcut_path)
        best_loss = current_loss
    return best_loss
|
|