File size: 9,538 Bytes
a03c9b4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 |
# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
""" data_modules.py """
from typing import Optional, Dict, List, Any
import os
import numpy as np
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities import CombinedLoader
from utils.datasets_train import get_cache_data_loader
from utils.datasets_eval import get_eval_dataloader
from utils.datasets_helper import create_merged_train_dataset_info, get_list_of_weighted_random_samplers
from utils.task_manager import TaskManager
from config.config import shared_cfg
from config.config import audio_cfg as default_audio_cfg
from config.data_presets import data_preset_single_cfg, data_preset_multi_cfg
class AMTDataModule(LightningDataModule):
    """LightningDataModule for multi-dataset AMT (automatic music transcription).

    Expands a multi-dataset preset (``data_preset_multi["presets"]``) into
    single-dataset presets, merges their train splits into one file list with
    weighted random samplers (one sampler per local sub-batch), and builds
    per-dataset validation/test dataloaders.
    """

    # Default cross-stem augmentation policy. Treated as read-only:
    # __init__ takes a per-instance copy so the shared default is never mutated.
    DEFAULT_STEM_XAUG_POLICY: Dict[str, Any] = {
        "max_k": 3,
        "tau": 0.3,
        "alpha": 1.0,
        "max_subunit_stems": 12,  # subunit stems are reduced to at most this many stems
        # Probability of including singing for cross-augmented examples.
        # If None, the base probability is used.
        "p_include_singing": 0.8,
        "no_instr_overlap": True,
        "no_drum_overlap": True,
        "uhat_intra_stem_augment": True,
    }

    def __init__(self,
                 data_home: Optional[os.PathLike] = None,
                 data_preset_multi: Optional[Dict[str, Any]] = None,
                 task_manager: Optional[TaskManager] = None,
                 train_num_samples_per_epoch: Optional[int] = None,
                 train_random_amp_range: Optional[List[float]] = None,
                 train_stem_iaug_prob: Optional[float] = 0.7,
                 train_stem_xaug_policy: Optional[Dict] = DEFAULT_STEM_XAUG_POLICY,
                 train_pitch_shift_range: Optional[List[int]] = None,
                 audio_cfg: Optional[Dict] = None) -> None:
        """
        Args:
            data_home: Root directory of the datasets. Defaults to
                ``shared_cfg["PATH"]["data_home"]``. Raises ValueError if the
                resolved path does not exist.
            data_preset_multi: Multi-preset config with a "presets" key listing
                single-preset names, and optional "val_max_num_files" /
                "test_max_num_files" caps. Only multi presets are allowed here;
                a single preset should be converted to a multi preset first.
                Defaults to ``{"presets": ["musicnet_mt3_synth_only"]}``.
            task_manager: Task/tokenization configuration. Defaults to a fresh
                ``TaskManager(task_name="mt3_full_plus")`` per instance.
            train_num_samples_per_epoch: Global number of train samples per
                epoch; divided by the local batch size before being handed to
                the samplers. None lets the samplers use their own default.
            train_random_amp_range: Random amplitude range for training audio.
                Defaults to ``[0.6, 1.2]``.
            train_stem_iaug_prob: Intra-stem augmentation probability.
            train_stem_xaug_policy: Cross-stem augmentation policy; see
                ``DEFAULT_STEM_XAUG_POLICY``. Passed through to the train
                dataloaders (a per-instance copy is stored).
            train_pitch_shift_range: Pitch-shift range in semitones, or None.
            audio_cfg: Audio configuration dict; defaults to the package-level
                ``default_audio_cfg``.
        """
        super().__init__()

        # Resolve and validate the dataset root directory.
        if data_home is None:
            data_home = shared_cfg["PATH"]["data_home"]
        if not os.path.exists(data_home):
            raise ValueError(f"Invalid data_home: {data_home}")
        self.data_home = data_home

        # Expand the multi preset into single-dataset presets, e.g.
        # [{"dataset_name": ..., "train_split": ..., "validation_split": ...}, ...]
        if data_preset_multi is None:
            data_preset_multi = {"presets": ["musicnet_mt3_synth_only"]}
        self.preset_multi = data_preset_multi
        self.preset_singles = []
        for dp in data_preset_multi["presets"]:
            if dp not in data_preset_single_cfg:
                raise ValueError(f"Invalid data_preset: {dp}")
            self.preset_singles.append(data_preset_single_cfg[dp])

        # Task manager: constructed lazily so the default is neither built at
        # import time nor shared across instances.
        self.task_manager = task_manager if task_manager is not None else TaskManager(task_name="mt3_full_plus")

        # Train samples per epoch, passed to the samplers in setup().
        self.train_num_samples_per_epoch = train_num_samples_per_epoch

        # One sampler per sub-batch within the local batch, so the local batch
        # size must be an exact multiple of the sub-batch size.
        if shared_cfg["BSZ"]["train_local"] % shared_cfg["BSZ"]["train_sub"] != 0:
            raise ValueError('shared_cfg["BSZ"]["train_local"] must be a multiple of '
                             'shared_cfg["BSZ"]["train_sub"]')
        self.num_train_samplers = shared_cfg["BSZ"]["train_local"] // shared_cfg["BSZ"]["train_sub"]

        # Train augmentation parameters. Mutable defaults are copied per
        # instance so no two instances share (or mutate) the same object.
        self.train_random_amp_range = (list(train_random_amp_range)
                                       if train_random_amp_range is not None else [0.6, 1.2])
        self.train_stem_iaug_prob = train_stem_iaug_prob
        self.train_stem_xaug_policy = (dict(train_stem_xaug_policy)
                                       if train_stem_xaug_policy is not None else None)
        self.train_pitch_shift_range = train_pitch_shift_range

        # Merged train data info; populated by set_merged_train_data_info() in setup().
        self.train_data_info = None

        # Optional caps on the number of validation/test files per dataset.
        self.val_max_num_files = data_preset_multi.get("val_max_num_files", None)
        self.test_max_num_files = data_preset_multi.get("test_max_num_files", None)

        # Audio configuration.
        self.audio_cfg = audio_cfg if audio_cfg is not None else default_audio_cfg

    def set_merged_train_data_info(self) -> None:
        """Collect the train datasets and create merged info.

        Sets ``self.train_data_info`` to::

            {
                "n_datasets": 0,
                "n_notes_per_dataset": [],
                "n_files_per_dataset": [],
                "dataset_names": [],       # dataset names by order of merging file lists
                "train_split_names": [],   # train split names by order of merging file lists
                "index_ranges": [],        # index ranges of each dataset in the merged file list
                "dataset_weights": [],     # pre-defined dataset weights for sampling, if available
                "merged_file_list": {},
            }
        """
        self.train_data_info = create_merged_train_dataset_info(self.preset_multi)
        print(
            f"AMTDataModule: Added {len(self.train_data_info['merged_file_list'])} files from {self.train_data_info['n_datasets']} datasets to the training set."
        )

    def setup(self, stage: str):
        """Prepare dataloader arguments for the given stage.

        `stage` ("fit" or "test") is passed automatically by the pytorch
        lightning Trainer.
        """
        if stage == "fit":
            # Set up the merged train data info.
            self.set_merged_train_data_info()

            # Distributed weighted random samplers for training. The requested
            # samples-per-epoch is a global count, so convert it to a per-step
            # count by dividing by the local batch size.
            if self.train_num_samples_per_epoch:
                actual_train_num_samples_per_epoch = (self.train_num_samples_per_epoch //
                                                      shared_cfg["BSZ"]["train_local"])
            else:
                actual_train_num_samples_per_epoch = None
            samplers = get_list_of_weighted_random_samplers(num_samplers=self.num_train_samplers,
                                                            dataset_weights=self.train_data_info["dataset_weights"],
                                                            dataset_index_ranges=self.train_data_info["index_ranges"],
                                                            num_samples_per_epoch=actual_train_num_samples_per_epoch)

            # Train dataloader arguments: one args dict per sampler, all
            # drawing sub-batches from the same merged file list.
            self.train_data_args = []
            for sampler in samplers:
                self.train_data_args.append({
                    "dataset_name": None,
                    "split": None,
                    "file_list": self.train_data_info["merged_file_list"],
                    "sub_batch_size": shared_cfg["BSZ"]["train_sub"],
                    "task_manager": self.task_manager,
                    "random_amp_range": self.train_random_amp_range,
                    "stem_iaug_prob": self.train_stem_iaug_prob,
                    "stem_xaug_policy": self.train_stem_xaug_policy,
                    "pitch_shift_range": self.train_pitch_shift_range,
                    "shuffle": True,
                    "sampler": sampler,
                    "audio_cfg": self.audio_cfg,
                })

            # Validation dataloader arguments: one per preset that defines a
            # validation split.
            self.val_data_args = []
            for preset_single in self.preset_singles:
                if preset_single["validation_split"] is not None:
                    self.val_data_args.append({
                        "dataset_name": preset_single["dataset_name"],
                        "split": preset_single["validation_split"],
                        "task_manager": self.task_manager,
                        "max_num_files": self.val_max_num_files,
                        "audio_cfg": self.audio_cfg,
                    })

        if stage == "test":
            # Test dataloader arguments: one per preset that defines a test split.
            self.test_data_args = []
            for preset_single in self.preset_singles:
                if preset_single["test_split"] is not None:
                    self.test_data_args.append({
                        "dataset_name": preset_single["dataset_name"],
                        "split": preset_single["test_split"],
                        "task_manager": self.task_manager,
                        "max_num_files": self.test_max_num_files,
                        "audio_cfg": self.audio_cfg,
                    })

    def train_dataloader(self) -> Any:
        """Return a CombinedLoader over one cache data loader per sampler."""
        loaders = {}
        for i, args_dict in enumerate(self.train_data_args):
            loaders[f"data_loader_{i}"] = get_cache_data_loader(**args_dict, dataloader_config=shared_cfg["DATAIO"])
        # "min_size" is safe here: all loaders draw from the same merged file
        # list with identically-sized samplers, so their sizes are identical.
        return CombinedLoader(loaders, mode="min_size")

    def val_dataloader(self) -> Any:
        """Return a dict of per-dataset validation dataloaders keyed by dataset name."""
        loaders = {}
        for args_dict in self.val_data_args:
            dataset_name = args_dict["dataset_name"]
            loaders[dataset_name] = get_eval_dataloader(**args_dict, dataloader_config=shared_cfg["DATAIO"])
        return loaders

    def test_dataloader(self) -> Any:
        """Return a dict of per-dataset test dataloaders keyed by dataset name."""
        loaders = {}
        for args_dict in self.test_data_args:
            dataset_name = args_dict["dataset_name"]
            loaders[dataset_name] = get_eval_dataloader(**args_dict, dataloader_config=shared_cfg["DATAIO"])
        return loaders

    # CombinedLoader in "sequential" mode returns dataloader_idx to the
    # trainer, which is used to get the dataset name in the logger.
    @property
    def num_val_dataloaders(self) -> int:
        """Number of validation dataloaders (requires setup("fit") to have run)."""
        return len(self.val_data_args)

    @property
    def num_test_dataloaders(self) -> int:
        """Number of test dataloaders (requires setup("test") to have run)."""
        return len(self.test_data_args)

    def get_val_dataset_name(self, dataloader_idx: int) -> str:
        """Return the dataset name for the validation dataloader at `dataloader_idx`."""
        return self.val_data_args[dataloader_idx]["dataset_name"]

    def get_test_dataset_name(self, dataloader_idx: int) -> str:
        """Return the dataset name for the test dataloader at `dataloader_idx`."""
        return self.test_data_args[dataloader_idx]["dataset_name"]
|