Spaces:
Sleeping
Sleeping
Sync from GitHub repo
Browse files
This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the Space there.
- model/dataset.py +17 -3
model/dataset.py
CHANGED
|
@@ -8,8 +8,10 @@ from torch.utils.data import Dataset, Sampler
|
|
| 8 |
import torchaudio
|
| 9 |
from datasets import load_from_disk
|
| 10 |
from datasets import Dataset as Dataset_
|
|
|
|
| 11 |
|
| 12 |
from model.modules import MelSpec
|
|
|
|
| 13 |
|
| 14 |
|
| 15 |
class HFDataset(Dataset):
|
|
@@ -77,15 +79,22 @@ class CustomDataset(Dataset):
|
|
| 77 |
hop_length=256,
|
| 78 |
n_mel_channels=100,
|
| 79 |
preprocessed_mel=False,
|
|
|
|
| 80 |
):
|
| 81 |
self.data = custom_dataset
|
| 82 |
self.durations = durations
|
| 83 |
self.target_sample_rate = target_sample_rate
|
| 84 |
self.hop_length = hop_length
|
| 85 |
self.preprocessed_mel = preprocessed_mel
|
|
|
|
| 86 |
if not preprocessed_mel:
|
| 87 |
-
self.mel_spectrogram =
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
)
|
| 90 |
|
| 91 |
def get_frame_len(self, index):
|
|
@@ -201,6 +210,7 @@ def load_dataset(
|
|
| 201 |
tokenizer: str = "pinyin",
|
| 202 |
dataset_type: str = "CustomDataset",
|
| 203 |
audio_type: str = "raw",
|
|
|
|
| 204 |
mel_spec_kwargs: dict = dict(),
|
| 205 |
) -> CustomDataset | HFDataset:
|
| 206 |
"""
|
|
@@ -224,7 +234,11 @@ def load_dataset(
|
|
| 224 |
data_dict = json.load(f)
|
| 225 |
durations = data_dict["duration"]
|
| 226 |
train_dataset = CustomDataset(
|
| 227 |
-
train_dataset,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
)
|
| 229 |
|
| 230 |
elif dataset_type == "CustomDatasetPath":
|
|
|
|
| 8 |
import torchaudio
|
| 9 |
from datasets import load_from_disk
|
| 10 |
from datasets import Dataset as Dataset_
|
| 11 |
+
from torch import nn
|
| 12 |
|
| 13 |
from model.modules import MelSpec
|
| 14 |
+
from model.utils import default
|
| 15 |
|
| 16 |
|
| 17 |
class HFDataset(Dataset):
|
|
|
|
| 79 |
hop_length=256,
|
| 80 |
n_mel_channels=100,
|
| 81 |
preprocessed_mel=False,
|
| 82 |
+
mel_spec_module: nn.Module | None = None,
|
| 83 |
):
|
| 84 |
self.data = custom_dataset
|
| 85 |
self.durations = durations
|
| 86 |
self.target_sample_rate = target_sample_rate
|
| 87 |
self.hop_length = hop_length
|
| 88 |
self.preprocessed_mel = preprocessed_mel
|
| 89 |
+
|
| 90 |
if not preprocessed_mel:
|
| 91 |
+
self.mel_spectrogram = default(
|
| 92 |
+
mel_spec_module,
|
| 93 |
+
MelSpec(
|
| 94 |
+
target_sample_rate=target_sample_rate,
|
| 95 |
+
hop_length=hop_length,
|
| 96 |
+
n_mel_channels=n_mel_channels,
|
| 97 |
+
),
|
| 98 |
)
|
| 99 |
|
| 100 |
def get_frame_len(self, index):
|
|
|
|
| 210 |
tokenizer: str = "pinyin",
|
| 211 |
dataset_type: str = "CustomDataset",
|
| 212 |
audio_type: str = "raw",
|
| 213 |
+
mel_spec_module: nn.Module | None = None,
|
| 214 |
mel_spec_kwargs: dict = dict(),
|
| 215 |
) -> CustomDataset | HFDataset:
|
| 216 |
"""
|
|
|
|
| 234 |
data_dict = json.load(f)
|
| 235 |
durations = data_dict["duration"]
|
| 236 |
train_dataset = CustomDataset(
|
| 237 |
+
train_dataset,
|
| 238 |
+
durations=durations,
|
| 239 |
+
preprocessed_mel=preprocessed_mel,
|
| 240 |
+
mel_spec_module=mel_spec_module,
|
| 241 |
+
**mel_spec_kwargs,
|
| 242 |
)
|
| 243 |
|
| 244 |
elif dataset_type == "CustomDatasetPath":
|