from glob import glob
import os

from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard


class LSTMChemTrainer(object):
    """Trains the model held by `modeler`, with per-epoch checkpointing and
    TensorBoard logging, using the given training and validation data loaders.
    """

    def __init__(self, modeler, train_data_loader, valid_data_loader):
        self.model = modeler.model
        self.config = modeler.config
        self.train_data_loader = train_data_loader
        self.valid_data_loader = valid_data_loader
        self.callbacks = []
        self.init_callbacks()

    def init_callbacks(self):
        # Checkpoint after every epoch; the filename embeds the epoch number
        # and validation loss, e.g. <exp_name>-05-0.42.hdf5.
        self.callbacks.append(
            ModelCheckpoint(
                filepath=os.path.join(
                    self.config.checkpoint_dir,
                    '%s-{epoch:02d}-{val_loss:.2f}.hdf5' %
                    self.config.exp_name),
                monitor=self.config.checkpoint_monitor,
                mode=self.config.checkpoint_mode,
                save_best_only=self.config.checkpoint_save_best_only,
                save_weights_only=self.config.checkpoint_save_weights_only,
                verbose=self.config.checkpoint_verbose,
            ))
        # Write scalar metrics (and optionally the model graph) for TensorBoard.
        self.callbacks.append(
            TensorBoard(
                log_dir=self.config.tensorboard_log_dir,
                write_graph=self.config.tensorboard_write_graph,
            ))

    def train(self):
        history = self.model.fit(
            self.train_data_loader,
            steps_per_epoch=len(self.train_data_loader),
            epochs=self.config.num_epochs,
            verbose=self.config.verbose_training,
            validation_data=self.valid_data_loader,
            validation_steps=len(self.valid_data_loader),
            use_multiprocessing=True,
            shuffle=True,
            callbacks=self.callbacks)

        # Locate the checkpoint written for the final epoch and record its
        # path in the config so downstream steps know which weights to load.
        last_weight_file = glob(
            os.path.join(
                self.config.checkpoint_dir,
                f'{self.config.exp_name}-{self.config.num_epochs:02}*.hdf5')
        )[0]

        assert os.path.exists(last_weight_file)
        self.config.model_weight_filename = last_weight_file

        # Persist the updated config alongside the experiment outputs.
        with open(os.path.join(self.config.exp_dir, 'config.json'), 'w') as f:
            f.write(self.config.toJSON(indent=2))
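

# A minimal usage sketch (not part of the trainer itself). The module paths and
# class names below (LSTMChem, DataLoader, process_config) are assumptions about
# the surrounding project, kept as comments so this file stays importable; they
# only illustrate how LSTMChemTrainer is wired together:
#
#     from lstm_chem.utils.config import process_config
#     from lstm_chem.model import LSTMChem
#     from lstm_chem.data_loader import DataLoader
#
#     config = process_config('path/to/config.json')
#     modeler = LSTMChem(config)
#     trainer = LSTMChemTrainer(
#         modeler,
#         train_data_loader=DataLoader(config, data_type='train'),
#         valid_data_loader=DataLoader(config, data_type='valid'))
#     trainer.train()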