# GenerSpeech: tasks/vocoder/vocoder_base.py
# Last change: "update" by Rongjiehuang (commit 222619b), 2.39 kB
import os
import torch
import torch.distributed as dist
from torch.utils.data import DistributedSampler
from tasks.base_task import BaseTask
from tasks.base_task import data_loader
from tasks.vocoder.dataset_utils import VocoderDataset, EndlessDistributedSampler
from utils.hparams import hparams
class VocoderBaseTask(BaseTask):
    """Base task for vocoders: builds per-split data loaders and the test output dir.

    Batch sizes come from the global ``hparams``:
      * ``max_sentences``       -- batch size of the training loader.
      * ``max_valid_sentences`` -- batch size of the valid/test loaders; ``-1``
        means "use ``max_sentences``" (the resolved value is written back into
        ``hparams``).
    """

    def __init__(self):
        super().__init__()
        self.max_sentences = hparams['max_sentences']
        self.max_valid_sentences = hparams['max_valid_sentences']
        if self.max_valid_sentences == -1:
            # Fall back to the training batch size. The resolved value is also
            # written back into hparams, presumably so other readers of hparams
            # see it too.
            hparams['max_valid_sentences'] = self.max_valid_sentences = self.max_sentences
        # Dataset class used by all three *_dataloader methods; kept on the
        # instance so a subclass can presumably substitute its own.
        self.dataset_cls = VocoderDataset

    @data_loader
    def train_dataloader(self):
        """Return the shuffled training loader (endless if ``hparams['endless_ds']``)."""
        train_dataset = self.dataset_cls('train', shuffle=True)
        return self.build_dataloader(train_dataset, True, self.max_sentences, hparams['endless_ds'])

    @data_loader
    def val_dataloader(self):
        """Return the non-shuffled validation loader."""
        valid_dataset = self.dataset_cls('valid', shuffle=False)
        return self.build_dataloader(valid_dataset, False, self.max_valid_sentences)

    @data_loader
    def test_dataloader(self):
        """Return the non-shuffled test loader."""
        test_dataset = self.dataset_cls('test', shuffle=False)
        return self.build_dataloader(test_dataset, False, self.max_valid_sentences)

    def build_dataloader(self, dataset, shuffle, max_sentences, endless=False):
        """Wrap ``dataset`` in a DataLoader driven by a distributed sampler.

        Args:
            dataset: dataset to load; must expose ``collater`` (used as
                ``collate_fn``) and ``num_workers``.
            shuffle: whether the sampler shuffles indices.
            max_sentences: batch size of the loader built on this process.
            endless: use ``EndlessDistributedSampler`` instead of the stock
                ``DistributedSampler``.

        Returns:
            A ``torch.utils.data.DataLoader`` over ``dataset``.
        """
        # Single-replica defaults. ``torch.distributed`` exposes no APIs other
        # than ``is_available`` when distributed support is missing from the
        # PyTorch build, so check availability before ``is_initialized``.
        world_size, rank = 1, 0
        if dist.is_available() and dist.is_initialized():
            world_size = dist.get_world_size()
            rank = dist.get_rank()
        sampler_cls = EndlessDistributedSampler if endless else DistributedSampler
        sampler = sampler_cls(
            dataset=dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=shuffle,
        )
        # Shuffling is delegated to the sampler; DataLoader does not allow
        # shuffle=True together with an explicit sampler.
        return torch.utils.data.DataLoader(
            dataset=dataset,
            shuffle=False,
            collate_fn=dataset.collater,
            batch_size=max_sentences,
            num_workers=dataset.num_workers,
            sampler=sampler,
            pin_memory=True,
        )

    def test_start(self):
        """Create ``<work_dir>/generated_<global_step>_<gen_dir_name>`` and store it in ``self.gen_dir``."""
        self.gen_dir = os.path.join(
            hparams['work_dir'],
            f'generated_{self.trainer.global_step}_{hparams["gen_dir_name"]}')
        os.makedirs(self.gen_dir, exist_ok=True)

    def test_end(self, outputs):
        """Ignore ``outputs`` and return an empty dict."""
        return {}