|
import os |
|
import time |
|
from tensorflow.keras import Sequential |
|
from tensorflow.keras.models import model_from_json |
|
from tensorflow.keras.layers import LSTM, Dense |
|
from tensorflow.keras.initializers import RandomNormal |
|
from lstm_chem.utils.smiles_tokenizer2 import SmilesTokenizer |
|
|
|
|
|
class LSTMChem(object):
    """Keras LSTM model over the SMILES tokenizer's token table.

    Depending on ``session``, the model is either freshly built and compiled
    (``'train'``) or restored from a saved architecture JSON file plus a
    weights checkpoint (``'generate'`` / ``'finetune'``).

    Attributes:
        config: Configuration object; the attributes each code path reads are
            listed on ``__init__`` and ``build_model``.
        session: One of ``SESSIONS``.
        model: The underlying Keras model, or ``None`` before it is created.
    """

    # Valid values for the ``session`` constructor argument.
    SESSIONS = ('train', 'generate', 'finetune')

    def __init__(self, config, session='train'):
        """Store the configuration and build or load the Keras model.

        Args:
            config: Configuration object. For ``'train'`` it must provide the
                attributes read by ``build_model``; otherwise it must provide
                ``model_arch_filename`` and ``model_weight_filename``.
            session: ``'train'`` builds a new model; ``'generate'`` and
                ``'finetune'`` load an existing one from disk.

        Raises:
            ValueError: If ``session`` is not one of ``SESSIONS``.
        """
        # Explicit check instead of ``assert`` so validation survives ``-O``.
        if session not in self.SESSIONS:
            raise ValueError(
                f'session must be one of {self.SESSIONS}, got {session!r}')

        self.config = config
        self.session = session
        self.model = None

        if self.session == 'train':
            self.build_model()
        else:
            self.model = self.load(self.config.model_arch_filename,
                                   self.config.model_weight_filename)

    def build_model(self):
        """Build, serialize and compile a fresh model.

        The network consists of two stacked LSTM layers (dropout 0.3 and 0.5),
        followed by a softmax ``Dense`` layer whose width equals the SMILES
        tokenizer's table size. The architecture is written as JSON to
        ``<config.exp_dir>/model_arch.json``, and ``config.model_arch_filename``
        is updated to point at that file. The model is then compiled with
        categorical cross-entropy loss.

        Reads ``config.seed``, ``config.units``, ``config.optimizer`` and
        ``config.exp_dir``.
        """
        st = SmilesTokenizer()
        n_table = len(st.table)
        weight_init = RandomNormal(mean=0.0,
                                   stddev=0.05,
                                   seed=self.config.seed)

        self.model = Sequential()
        # Only the first layer of a Sequential model needs ``input_shape``:
        # variable-length sequences (None) of n_table-wide vectors
        # (presumably one-hot tokens; confirm against the data loader).
        self.model.add(
            LSTM(units=self.config.units,
                 input_shape=(None, n_table),
                 return_sequences=True,
                 kernel_initializer=weight_init,
                 dropout=0.3))
        self.model.add(
            LSTM(units=self.config.units,
                 return_sequences=True,
                 kernel_initializer=weight_init,
                 dropout=0.5))
        # With return_sequences=True above, Dense acts on the last axis at
        # every time step, giving a softmax over the token table per position.
        self.model.add(
            Dense(units=n_table,
                  activation='softmax',
                  kernel_initializer=weight_init))

        arch = self.model.to_json(indent=2)
        self.config.model_arch_filename = os.path.join(self.config.exp_dir,
                                                       'model_arch.json')
        with open(self.config.model_arch_filename, 'w',
                  encoding='utf-8') as f:
            f.write(arch)

        self.model.compile(optimizer=self.config.optimizer,
                           loss='categorical_crossentropy')

    def save(self, checkpoint_path):
        """Write the model's weights (not its architecture) to a file.

        Args:
            checkpoint_path: Destination path passed to ``save_weights``.

        Raises:
            RuntimeError: If no model has been built or loaded yet.
        """
        # ``is None`` rather than truthiness, and a real exception rather
        # than ``assert``, so the guard is reliable and survives ``-O``.
        if self.model is None:
            raise RuntimeError('You have to build the model first.')

        print('Saving model ...')
        self.model.save_weights(checkpoint_path)
        print('model saved.')

    def load(self, model_arch_file, checkpoint_file):
        """Rebuild a model from an architecture JSON file and a weights file.

        Args:
            model_arch_file: Path to JSON text accepted by ``model_from_json``
                (as written by ``build_model``).
            checkpoint_file: Path to weights accepted by ``load_weights``
                (as written by ``save``).

        Returns:
            The restored Keras model. It is not compiled here.
        """
        print(f'Loading model architecture from {model_arch_file} ...')
        with open(model_arch_file, encoding='utf-8') as f:
            model = model_from_json(f.read())
        print(f'Loading model checkpoint from {checkpoint_file} ...')
        model.load_weights(checkpoint_file)
        print('Loaded the Model.')
        return model
|
|