|
import os |
|
import numpy as np |
|
import shutil |
|
import resource |
|
import options as opt |
|
|
|
from helpers import * |
|
from datetime import datetime as Datetime |
|
from tensorboardX import SummaryWriter |
|
|
|
def raise_nofile_limit(target=65536):
    """Raise this process's soft open-file limit (RLIMIT_NOFILE) to *target*.

    The target is capped at the hard limit, because ``setrlimit`` rejects a
    soft limit above the hard one. A soft limit that is already at least
    *target*, or unlimited, is left untouched. If the platform refuses the
    change, a ``RuntimeWarning`` is emitted instead of raising, so importing
    this module never fails because of it.

    Args:
        target: Desired soft limit on open file descriptors.

    Returns:
        The ``(soft, hard)`` pair in effect when the function returns.
    """
    soft, hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    if hard != resource.RLIM_INFINITY:
        # Asking for more than the hard limit would raise ValueError.
        target = min(target, hard)
    if soft == resource.RLIM_INFINITY or soft >= target:
        # Never lower a limit that is already sufficient.
        return soft, hard
    try:
        resource.setrlimit(resource.RLIMIT_NOFILE, (target, hard))
    except (ValueError, OSError) as err:
        # Some platforms reject large soft values even when the hard limit is
        # reported as unlimited; keep the current limit rather than crash.
        import warnings
        warnings.warn(
            f'could not raise RLIMIT_NOFILE to {target}: {err}',
            RuntimeWarning,
        )
        return soft, hard
    return target, hard


# (soft, hard) open-file limits as they were before this module changed them.
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)

raise_nofile_limit()
|
|
|
|
|
class BaseTrainer(object):
    """Common run bookkeeping shared by trainers.

    Provides a timestamped run name, dataset keyword arguments built from the
    ``options`` module, creation of the run's TensorBoard log and weights
    directories, and a small scalar-logging helper.

    Attributes:
        name: Model name; prefix of ``save_name``.
        base_dir: Stored as given. No method of this class reads it
            (presumably used by subclasses -- TODO confirm).
        date_stamp: Local-time ``%y%m%d-%H%M`` stamp taken in ``__init__``.
        save_name: ``'<name>-<date_stamp>'``; names this run's directories.
        weights_dir: ``'weights/<save_name>'`` once ``init_tensorboard`` has
            run, otherwise ``None``.
        log_dir: ``'runs/<save_name>'`` once ``init_tensorboard`` has run,
            otherwise ``None``.
        writer: ``SummaryWriter`` on ``log_dir`` once ``init_tensorboard``
            has run, otherwise ``None``.
    """

    def __init__(self, name='M', base_dir=''):
        """Record the run name and timestamp; no directories are created.

        Args:
            name: Model name used as the prefix of ``save_name``.
            base_dir: Base directory kept on the instance as ``base_dir``.
        """
        self.name = name
        self.base_dir = base_dir

        self.date_stamp = self.make_date_stamp()
        self.save_name = f'{self.name}-{self.date_stamp}'
        # The three attributes below are populated by init_tensorboard().
        self.weights_dir = None
        self.log_dir = None
        self.writer = None

    @staticmethod
    def get_dataset_kwargs(
        shared_dict=None, base_dir='',
        char_map=opt.char_map, **kwargs
    ):
        """Build the keyword arguments used to construct a dataset.

        The video/alignment/image/phoneme paths, paddings and the
        frame-doubling flag come from the ``options`` module; the remaining
        values come from the arguments.

        Args:
            shared_dict: Passed through unchanged as ``shared_dict``.
            base_dir: Passed through unchanged as ``base_dir``.
            char_map: Character map to pass on. NOTE: the default is
                ``opt.char_map`` as it was when this class was defined;
                reassigning ``opt.char_map`` afterwards does not affect it.
            **kwargs: Extra keyword arguments forwarded verbatim. Repeating
                one of the keys set here (e.g. ``video_path``) raises
                ``TypeError`` for a duplicate keyword argument.

        Returns:
            Whatever ``kwargify`` (from ``helpers``) returns for these keyword
            arguments -- presumably a dict of them; TODO confirm.
        """
        return kwargify(
            video_path=opt.video_path,
            shared_dict=shared_dict,
            alignments_dir=opt.alignments_dir,
            vid_pad=opt.vid_padding,
            image_dir=opt.images_dir,
            txt_pad=opt.txt_padding,
            phonemes_dir=opt.phonemes_dir,
            frame_doubling=opt.frame_doubling,
            char_map=char_map,
            base_dir=base_dir,
            **kwargs
        )

    def init_tensorboard(self):
        """Create this run's log/weights directories and a SummaryWriter.

        Sets ``self.log_dir`` to ``runs/<save_name>`` and ``self.weights_dir``
        to ``weights/<save_name>``, both relative to the current working
        directory. Creates both directories, including missing parents, and
        opens ``self.writer`` on the log directory. Finally it snapshots the
        options module into the log directory as ``options.py``, so the
        configuration is stored next to the run's logs.
        """
        self.log_dir = f'runs/{self.save_name}'
        self.weights_dir = f'weights/{self.save_name}'

        # makedirs also creates the 'runs'/'weights' parents (os.mkdir raised
        # FileNotFoundError when they were missing); exist_ok replaces the
        # racy exists-then-create check and tolerates an existing directory.
        os.makedirs(self.log_dir, exist_ok=True)
        os.makedirs(self.weights_dir, exist_ok=True)

        self.writer = SummaryWriter(self.log_dir)

        # Snapshot the options module that was actually imported, which need
        # not be the 'options.py' in the current directory. Fall back to the
        # old relative path if the module has no usable source file.
        options_src = getattr(opt, '__file__', None)
        if not options_src or not options_src.endswith('.py'):
            options_src = 'options.py'
        shutil.copyfile(
            options_src, os.path.join(self.log_dir, 'options.py')
        )

    @staticmethod
    def make_date_stamp():
        """Return the current local time formatted as ``YYMMDD-HHMM``.

        The stamp has minute resolution, so two trainers created within the
        same minute with the same ``name`` share a ``save_name``.
        """
        return Datetime.now().strftime("%y%m%d-%H%M")

    def log_scalar(self, name, value, iterations, label):
        """Log one scalar value to TensorBoard.

        Calls ``self.writer.add_scalars(name, {label: value}, iterations)``:
        ``label`` is the series key under the main tag ``name`` and
        ``iterations`` is the global step. ``init_tensorboard`` must have been
        called first; otherwise ``self.writer`` is ``None`` and this raises
        ``AttributeError``.
        """
        self.writer.add_scalars(name, {label: value}, iterations)