{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "wWSzEAjavzrb", "outputId": "3c8987e4-5537-46ed-fc4b-0988e03b376c" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Requirement already satisfied: bunch in /usr/local/lib/python3.8/dist-packages (1.0.1)\n" ] } ], "source": [
 "# %pip installs into the running kernel's environment; the version is pinned\n",
 "# to the one this notebook was last run with.\n",
 "%pip install bunch==1.0.1"
 ] }, { "cell_type": "code", "source": [ "from google.colab import drive\n", "drive.mount('/content/drive')" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "PXp_V--04b4N", "outputId": "35b31fa0-1cde-4ddc-b82d-af62805a5c76" }, "execution_count": null, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount(\"/content/drive\", force_remount=True).\n" ] } ] }, { "cell_type": "markdown", "metadata": { "id": "7mgq8uVbyW4d" }, "source": [ "## Utility Codes" ] }, { "cell_type": "markdown", "metadata": { "id": "oyGzbJOiyau5" }, "source": [ "### Parameter Setting" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FSYvCcgkvxtK" }, "outputs": [], "source": [
 "import os\n",
 "import time\n",
 "import json\n",
 "from bunch import Bunch\n",
 "\n",
 "\n",
 "def get_config_from_json(json_file):\n",
 "    \"\"\"Load a JSON config file into a ``Bunch`` (a dict with attribute access).\n",
 "\n",
 "    Args:\n",
 "        json_file: path to the JSON configuration file.\n",
 "\n",
 "    Returns:\n",
 "        Bunch built from the decoded JSON object.\n",
 "    \"\"\"\n",
 "    with open(json_file, 'r') as config_file:\n",
 "        config_dict = json.load(config_file)\n",
 "    config = Bunch(config_dict)\n",
 "    return config\n",
 "\n",
 "\n",
 "def process_config(json_file):\n",
 "    \"\"\"Load ``json_file`` and add the per-run experiment paths.\n",
 "\n",
 "    Sets ``config_file``, ``exp_dir`` (experiments/YYYY-MM-DD/exp_name),\n",
 "    ``tensorboard_log_dir`` (exp_dir/logs/) and ``checkpoint_dir``\n",
 "    (exp_dir/checkpoints/). The directories are not created here.\n",
 "\n",
 "    Args:\n",
 "        json_file: path to the JSON configuration file; must define exp_name.\n",
 "\n",
 "    Returns:\n",
 "        The config Bunch with the path fields added.\n",
 "    \"\"\"\n",
 "    config = get_config_from_json(json_file)\n",
 "    config.config_file = json_file\n",
 "    # Read the date once so all three paths agree even if this runs across\n",
 "    # midnight.\n",
 "    date_dir = time.strftime('%Y-%m-%d/', time.localtime())\n",
 "    config.exp_dir = os.path.join('experiments', date_dir, config.exp_name)\n",
 "    config.tensorboard_log_dir = os.path.join(config.exp_dir, 'logs/')\n",
 "    config.checkpoint_dir = os.path.join(config.exp_dir, 'checkpoints/')\n",
 "    return config"
 ] }, { "cell_type": "markdown", "metadata": { "id": "tJp7vxoyyhej" }, "source": [ "### Creating Directory" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Bx9GeQ9qwEXy" }, "outputs": [], "source": [
 "import os\n",
 "import sys\n",
 "\n",
 "\n",
 "def create_dirs(dirs):\n",
 "    \"\"\"Create each directory in ``dirs`` (with parents) if it is missing.\n",
 "\n",
 "    Args:\n",
 "        dirs: iterable of directory paths.\n",
 "\n",
 "    On any error prints a message and exits with status 1.\n",
 "    \"\"\"\n",
 "    try:\n",
 "        for dir_ in dirs:\n",
 "            # exist_ok avoids the check-then-create race of os.path.exists.\n",
 "            os.makedirs(dir_, exist_ok=True)\n",
 "    except Exception as err:\n",
 "        print(f'Creating directories error: {err}')\n",
 "        # Non-zero status so a failure is not reported as success.\n",
 "        sys.exit(1)"
 ] }, { "cell_type": "markdown", "metadata": { "id": "7ME9uPedQj00" }, "source": [ "## Building Function Vocabulary" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 424 }, "id": "5MokMxOEQiaO", "outputId": "d4eb662d-400e-41f5-ed34-f1278ab6520b" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " Unnamed: 0 Len Functional grp name \\\n", "0 0 7 Acid anhydrides \n", "1 1 8 Diacyl peroxides \n", "2 2 15 Phenyl esters \n", "3 3 5 Isocyanates \n", "4 4 8 β-Lactams \n", "... ... ... ... \n", "1240 1240 12 SA46 \n", "1241 1241 11 SA47 \n", "1242 1242 19 SA48 \n", "1243 1243 15 SA49 \n", "1244 1244 6 Primary amines (strict) \n", "\n", " SMARTS Smiles \\\n", "0 [CX3](=[OX1])[OX2][CX3](=[OX1]) O=COC=O \n", "1 [CX3](=[OX1])[OX2][OX2][CX3](=[OX1]) O=COOC=O \n", "2 [cR1]1([OX2][CX3](=[OX1])[#6])[cR1][cR1][cR1][... CC(=O)Oc1ccccc1 \n", "3 [NX2]=[CX2]=[OX1] N=C=O \n", "4 [NX3]1[CX4][CX4][CX3]1(=[OX1]) O=C1CCN1 \n", "... ... ... 
\n", "1240 [#6]-1-2-[#6]=[#6]-[#6](-[#6]-1)-[#6]-[#6]-2 C1=CC2CCC1C2 \n", "1241 [#6]-1-2-[#6]-[#6]-[#6](-[#6]-1)-[#6]-[#6]-2 C1CC2CCC1C2 \n", "1242 c1(c(cccc1)-[#8]-c2ccccc2)-[Br] Brc1ccccc1Oc1ccccc1 \n", "1243 c1-2c(cccc1)-[#6]-[#6]-[#6]-2 c1ccc2c(c1)CCC2 \n", "1244 [NX3H2][CX4&!$([CX4]([NH2])[O,N,S,P])] C[NH2] \n", "\n", " Validity \n", "0 Valid \n", "1 Valid \n", "2 Valid \n", "3 Valid \n", "4 Valid \n", "... ... \n", "1240 Valid \n", "1241 Valid \n", "1242 Valid \n", "1243 Valid \n", "1244 Valid \n", "\n", "[1245 rows x 6 columns]" ], "text/html": [ "\n", "
\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Unnamed: 0LenFunctional grp nameSMARTSSmilesValidity
007Acid anhydrides[CX3](=[OX1])[OX2][CX3](=[OX1])O=COC=OValid
118Diacyl peroxides[CX3](=[OX1])[OX2][OX2][CX3](=[OX1])O=COOC=OValid
2215Phenyl esters[cR1]1([OX2][CX3](=[OX1])[#6])[cR1][cR1][cR1][...CC(=O)Oc1ccccc1Valid
335Isocyanates[NX2]=[CX2]=[OX1]N=C=OValid
448β-Lactams[NX3]1[CX4][CX4][CX3]1(=[OX1])O=C1CCN1Valid
.....................
1240124012SA46[#6]-1-2-[#6]=[#6]-[#6](-[#6]-1)-[#6]-[#6]-2C1=CC2CCC1C2Valid
1241124111SA47[#6]-1-2-[#6]-[#6]-[#6](-[#6]-1)-[#6]-[#6]-2C1CC2CCC1C2Valid
1242124219SA48c1(c(cccc1)-[#8]-c2ccccc2)-[Br]Brc1ccccc1Oc1ccccc1Valid
1243124315SA49c1-2c(cccc1)-[#6]-[#6]-[#6]-2c1ccc2c(c1)CCC2Valid
124412446Primary amines (strict)[NX3H2][CX4&!$([CX4]([NH2])[O,N,S,P])]C[NH2]Valid
\n", "

1245 rows × 6 columns

\n", "
\n", " \n", " \n", " \n", "\n", " \n", "
\n", "
\n", " " ] }, "metadata": {}, "execution_count": 36 } ], "source": [ "import pandas as pd\n", "\n", "df = pd.read_csv('functions.csv')\n", "df\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "VpdiNPeLQiNb", "outputId": "a6cd7530-4f57-4c17-a870-cc846e5d0c16" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "(673, 6)\n" ] } ], "source": [
 "# `Len` is the length of the `Smiles` string. The tokenizer matches tokens of\n",
 "# at most 12 characters, so longer fragments could never be emitted.\n",
 "MAX_FUNCTION_TOKEN_LEN = 12\n",
 "df_filtered = df[df.Len <= MAX_FUNCTION_TOKEN_LEN]\n",
 "assert df_filtered['Smiles'].str.len().le(MAX_FUNCTION_TOKEN_LEN).all()\n",
 "print(df_filtered.shape)\n",
 "vocab = df_filtered['Smiles'].tolist()"
 ] }, { "cell_type": "markdown", "metadata": { "id": "82goiSvmynck" }, "source": [ "## SMILES Tokenizer" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CCJOzUKwuydA" }, "outputs": [], "source": [
 "import numpy as np\n",
 "\n",
 "\n",
 "class SmilesTokenizer(object):\n",
 "    \"\"\"Greedy longest-match SMILES tokenizer with a one-hot lookup table.\n",
 "\n",
 "    The token table is, in order: the functional-group SMILES fragments\n",
 "    (longest first, each listed twice), the padding symbols 'G' (start),\n",
 "    'A' (pad) and 'E' (end), and single SMILES syntax characters. The table\n",
 "    order fixes the model's input/output width and the meaning of each\n",
 "    index, so it must not change while existing checkpoints are in use.\n",
 "\n",
 "    NOTE(review): listing every fragment twice wastes output units, and bare\n",
 "    atom symbols such as 'N', 'O' or 'Cl' are tokens only if they occur in\n",
 "    the fragment vocab (other characters are dropped by tokenize()). Both are\n",
 "    kept for checkpoint compatibility; revisit them when retraining.\n",
 "    \"\"\"\n",
 "\n",
 "    # Longest token (in characters) that tokenize() tries to match.\n",
 "    MAX_TOKEN_LEN = 12\n",
 "\n",
 "    def __init__(self, function_vocab=None):\n",
 "        \"\"\"Build the token table and the one-hot lookup.\n",
 "\n",
 "        Args:\n",
 "            function_vocab: SMILES fragments to use as multi-character\n",
 "                tokens. Defaults to the notebook-level ``vocab`` list.\n",
 "        \"\"\"\n",
 "        if function_vocab is None:\n",
 "            function_vocab = vocab\n",
 "        special = [\n",
 "            '(', ')', '[', ']', '=', '#', '%', '0', '1', '2', '3', '4', '5',\n",
 "            '6', '7', '8', '9', '+', '-', 'se', 'te', 'c', 'n', 'o', 's'\n",
 "        ]\n",
 "        padding = ['G', 'A', 'E']\n",
 "\n",
 "        fragments = sorted(function_vocab, key=len, reverse=True)\n",
 "        self.table = fragments + fragments + padding + special\n",
 "        table_len = len(self.table)\n",
 "\n",
 "        # Tokens grouped by length: the public table_N_chars lists are kept for\n",
 "        # compatibility; the frozensets give O(1) lookups in tokenize().\n",
 "        self._tokens_by_len = {}\n",
 "        for n in range(1, self.MAX_TOKEN_LEN + 1):\n",
 "            tokens_n = [t for t in self.table if len(t) == n]\n",
 "            setattr(self, f'table_{n}_chars', tokens_n)\n",
 "            self._tokens_by_len[n] = frozenset(tokens_n)\n",
 "\n",
 "        # For symbols listed more than once, the last index wins.\n",
 "        self.one_hot_dict = {}\n",
 "        for i, symbol in enumerate(self.table):\n",
 "            vec = np.zeros(table_len, dtype=np.float32)\n",
 "            vec[i] = 1\n",
 "            self.one_hot_dict[symbol] = vec\n",
 "\n",
 "    def tokenize(self, smiles):\n",
 "        \"\"\"Split ``smiles`` into table tokens by greedy longest match.\n",
 "\n",
 "        At each position the longest table entry (up to MAX_TOKEN_LEN chars)\n",
 "        is consumed; a character that starts no entry is skipped.\n",
 "\n",
 "        Args:\n",
 "            smiles: SMILES string.\n",
 "\n",
 "        Returns:\n",
 "            List of token strings.\n",
 "        \"\"\"\n",
 "        tokens = []\n",
 "        i = 0\n",
 "        n_chars = len(smiles)\n",
 "        while i < n_chars:\n",
 "            for width in range(self.MAX_TOKEN_LEN, 0, -1):\n",
 "                candidate = smiles[i:i + width]\n",
 "                if candidate in self._tokens_by_len[width]:\n",
 "                    tokens.append(candidate)\n",
 "                    i += width\n",
 "                    break\n",
 "            else:\n",
 "                # NOTE(review): unknown characters are dropped silently, which\n",
 "                # changes the molecule; consider logging them.\n",
 "                i += 1\n",
 "        return tokens\n",
 "\n",
 "    def one_hot_encode(self, tokenized_smiles):\n",
 "        \"\"\"One-hot encode tokens as a float32 array of shape (1, n_tokens, len(table)).\n",
 "\n",
 "        Raises:\n",
 "            KeyError: if a token is not in the table.\n",
 "        \"\"\"\n",
 "        result = np.array(\n",
 "            [self.one_hot_dict[symbol] for symbol in tokenized_smiles],\n",
 "            dtype=np.float32)\n",
 "        # -1 also covers an empty token list, giving shape (1, 0, len(table)).\n",
 "        return result.reshape(1, -1, len(self.table))"
 ] }, { "cell_type": "markdown", "metadata": { "id": "9D0FvvsmywFk" }, "source": [ "## Data Loader" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "N0m-AmjPufqR" }, "outputs": [], "source": [
 "import json\n",
 "import os\n",
 "import numpy as np\n",
 "from tqdm import tqdm\n",
 "from tensorflow.keras.utils import Sequence\n",
 "\n",
 "\n",
 "class DataLoader(Sequence):\n",
 "    \"\"\"Keras Sequence of one-hot (X, y) batches for next-token training.\n",
 "\n",
 "    data_type:\n",
 "        'train' / 'valid': read config.data_filename and split it with\n",
 "            config.validation_split after a shuffle seeded by config.seed,\n",
 "            so the two types see complementary parts of the same file.\n",
 "        'finetune': read config.finetune_data_filename and use all of it.\n",
 "    \"\"\"\n",
 "\n",
 "    def __init__(self, config, data_type='train'):\n",
 "        self.config = config\n",
 "        self.data_type = data_type\n",
 "        assert self.data_type in ['train', 'valid', 'finetune']\n",
 "\n",
 "        self.max_len = 0\n",
 "\n",
 "        if self.data_type == 'finetune':\n",
 "            self.smiles = self._load(self.config.finetune_data_filename)\n",
 "        else:\n",
 "            # 'valid' reads the same file as 'train'; the seeded shuffle below\n",
 "            # then reproduces the train loader's split, so a standalone\n",
 "            # validation loader matches one made with copy(train_loader).\n",
 "            self.smiles = self._load(self.config.data_filename)\n",
 "\n",
 "        self.st = SmilesTokenizer()\n",
 "        self.one_hot_dict = self.st.one_hot_dict\n",
 "\n",
 "        self.tokenized_smiles = self._tokenize(self.smiles)\n",
 "\n",
 "        if self.data_type in ['train', 'valid']:\n",
 "            self.idx = np.arange(len(self.tokenized_smiles))\n",
 "            self.valid_size = int(\n",
 "                np.ceil(\n",
 "                    len(self.tokenized_smiles) * self.config.validation_split))\n",
 "            np.random.seed(self.config.seed)\n",
 "            np.random.shuffle(self.idx)\n",
 "\n",
 "    def _set_data(self):\n",
 "        \"\"\"Return the tokenized SMILES for the current data_type.\n",
 "\n",
 "        The result is cached per data_type so Keras' per-batch calls do not\n",
 "        rebuild the split. data_type is re-checked on every call because the\n",
 "        validation loader is often a copy of the train loader with data_type\n",
 "        switched afterwards.\n",
 "        \"\"\"\n",
 "        cache = getattr(self, '_data_cache', None)\n",
 "        if cache is not None and cache[0] == self.data_type:\n",
 "            return cache[1]\n",
 "\n",
 "        if self.data_type == 'train':\n",
 "            positions = self.idx[self.valid_size:]\n",
 "        elif self.data_type == 'valid':\n",
 "            positions = self.idx[:self.valid_size]\n",
 "        else:\n",
 "            positions = None\n",
 "\n",
 "        if positions is None:\n",
 "            ret = self.tokenized_smiles\n",
 "        else:\n",
 "            # The double indirection idx[i] for i in idx[...] is kept so the\n",
 "            # train/valid split matches earlier runs.\n",
 "            ret = [self.tokenized_smiles[self.idx[i]] for i in positions]\n",
 "        self._data_cache = (self.data_type, ret)\n",
 "        return ret\n",
 "\n",
 "    def _load(self, data_filename):\n",
 "        \"\"\"Read one SMILES per line; keep the first config.data_length unless it is 0.\"\"\"\n",
 "        length = self.config.data_length\n",
 "        print('loading SMILES...')\n",
 "        with open(data_filename) as f:\n",
 "            smiles = [s.rstrip() for s in f]\n",
 "        if length != 0:\n",
 "            smiles = smiles[:length]\n",
 "        print('done.')\n",
 "        return smiles\n",
 "\n",
 "    def _tokenize(self, smiles):\n",
 "        \"\"\"Tokenize every SMILES string.\n",
 "\n",
 "        For the training file ('train'/'valid') also record the longest token\n",
 "        count in self.max_len; 'train' additionally stores it in\n",
 "        config.train_smi_max_len.\n",
 "        \"\"\"\n",
 "        assert isinstance(smiles, list)\n",
 "        print('tokenizing SMILES...')\n",
 "        tokenized_smiles = [self.st.tokenize(smi) for smi in tqdm(smiles)]\n",
 "\n",
 "        if self.data_type in ['train', 'valid']:\n",
 "            self.max_len = max(\n",
 "                (len(tokenized_smi) for tokenized_smi in tokenized_smiles),\n",
 "                default=0)\n",
 "            if self.data_type == 'train':\n",
 "                self.config.train_smi_max_len = self.max_len\n",
 "        print('done.')\n",
 "        return tokenized_smiles\n",
 "\n",
 "    def _batch_size(self):\n",
 "        \"\"\"Batch size for the current data_type.\"\"\"\n",
 "        if self.data_type in ['train', 'valid']:\n",
 "            return self.config.batch_size\n",
 "        return self.config.finetune_batch_size\n",
 "\n",
 "    def __len__(self):\n",
 "        \"\"\"Number of batches per epoch (the last batch may be partial).\"\"\"\n",
 "        n_items = len(self._set_data())\n",
 "        return int(np.ceil(n_items / float(self._batch_size())))\n",
 "\n",
 "    def __getitem__(self, idx):\n",
 "        \"\"\"Return batch idx as float32 arrays (X, y).\n",
 "\n",
 "        X holds tokens [:-1] and y tokens [1:] of each padded sequence, so y\n",
 "        is X shifted one step ahead. The arrays are also stored on self.X and\n",
 "        self.y.\n",
 "        \"\"\"\n",
 "        batch_size = self._batch_size()\n",
 "        data = self._set_data()[idx * batch_size:(idx + 1) * batch_size]\n",
 "        data = self._padding(data)\n",
 "\n",
 "        self.X, self.y = [], []\n",
 "        for tp_smi in data:\n",
 "            self.X.append([self.one_hot_dict[symbol] for symbol in tp_smi[:-1]])\n",
 "            self.y.append([self.one_hot_dict[symbol] for symbol in tp_smi[1:]])\n",
 "\n",
 "        self.X = np.array(self.X, dtype=np.float32)\n",
 "        self.y = np.array(self.y, dtype=np.float32)\n",
 "        return self.X, self.y\n",
 "\n",
 "    def _pad(self, tokenized_smi, max_len=None):\n",
 "        \"\"\"Wrap tokens in 'G' ... 'E' and right-pad with 'A'.\n",
 "\n",
 "        Args:\n",
 "            tokenized_smi: token list.\n",
 "            max_len: token count to pad to, excluding 'G'/'E'. Defaults to\n",
 "                self.max_len.\n",
 "        \"\"\"\n",
 "        target_len = self.max_len if max_len is None else max_len\n",
 "        return ['G'] + tokenized_smi + ['E'] + [\n",
 "            'A' for _ in range(target_len - len(tokenized_smi))\n",
 "        ]\n",
 "\n",
 "    def _padding(self, data):\n",
 "        \"\"\"Pad every sequence in data to a common length.\n",
 "\n",
 "        The target is the larger of self.max_len (training-file maximum; 0 for\n",
 "        fine-tuning) and the longest sequence in the batch, so fine-tuning\n",
 "        batches larger than 1 are not ragged.\n",
 "        \"\"\"\n",
 "        batch_max = max((len(t_smi) for t_smi in data), default=0)\n",
 "        target_len = max(self.max_len, batch_max)\n",
 "        return [self._pad(t_smi, target_len) for t_smi in data]"
 ] }, { "cell_type": "markdown", "metadata": { "id": "NbUu9dy63XKJ" }, "source": [ "## LSTM Model" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "IymzAnUUvG1v" }, "outputs": [], "source": [
 "import os\n",
 "import time\n",
 "from tensorflow.keras import Sequential\n",
 "from tensorflow.keras.models import model_from_json\n",
 "from tensorflow.keras.layers import LSTM, Dense\n",
 "from tensorflow.keras.initializers import RandomNormal\n",
 "\n",
 "\n",
 "class LSTMChem(object):\n",
 "    \"\"\"Two stacked LSTM layers followed by a softmax over the tokenizer table.\n",
 "\n",
 "    session='train' builds and compiles a new model and writes its\n",
 "    architecture JSON to config.exp_dir; 'generate' and 'finetune' load the\n",
 "    architecture and weights named by config.model_arch_filename and\n",
 "    config.model_weight_filename.\n",
 "    \"\"\"\n",
 "\n",
 "    def __init__(self, config, session='train'):\n",
 "        assert session in ['train', 'generate', 'finetune'], \\\n",
 "            'one of {train, generate, finetune}'\n",
 "\n",
 "        self.config = config\n",
 "        self.session = session\n",
 "        self.model = None\n",
 "\n",
 "        if self.session == 'train':\n",
 "            self.build_model()\n",
 "        else:\n",
 "            self.model = self.load(self.config.model_arch_filename,\n",
 "                                   self.config.model_weight_filename)\n",
 "\n",
 "    def build_model(self):\n",
 "        \"\"\"Build, save (architecture JSON) and compile the model.\n",
 "\n",
 "        Sets config.model_arch_filename; config.exp_dir must already exist.\n",
 "        \"\"\"\n",
 "        st = SmilesTokenizer()\n",
 "        n_table = len(st.table)\n",
 "        weight_init = RandomNormal(mean=0.0,\n",
 "                                   stddev=0.05,\n",
 "                                   seed=self.config.seed)\n",
 "\n",
 "        self.model = Sequential()\n",
 "        self.model.add(\n",
 "            LSTM(units=self.config.units,\n",
 "                 input_shape=(None, n_table),\n",
 "                 return_sequences=True,\n",
 "                 kernel_initializer=weight_init,\n",
 "                 dropout=0.1))\n",
 "        # Only the first layer needs input_shape; later ones infer it.\n",
 "        self.model.add(\n",
 "            LSTM(units=self.config.units,\n",
 "                 return_sequences=True,\n",
 "                 kernel_initializer=weight_init,\n",
 "                 dropout=0.2))\n",
 "        self.model.add(\n",
 "            Dense(units=n_table,\n",
 "                  activation='softmax',\n",
 "                  kernel_initializer=weight_init))\n",
 "\n",
 "        arch = self.model.to_json(indent=2)\n",
 "        self.config.model_arch_filename = os.path.join(self.config.exp_dir,\n",
 "                                                       'model_arch.json')\n",
 "        with open(self.config.model_arch_filename, 'w') as f:\n",
 "            f.write(arch)\n",
 "\n",
 "        self.model.compile(optimizer=self.config.optimizer,\n",
 "                           loss='categorical_crossentropy')\n",
 "\n",
 "    def save(self, checkpoint_path):\n",
 "        \"\"\"Save the model weights to checkpoint_path.\"\"\"\n",
 "        assert self.model, 'You have to build the model first.'\n",
 "\n",
 "        print('Saving model ...')\n",
 "        self.model.save_weights(checkpoint_path)\n",
 "        print('model saved.')\n",
 "\n",
 "    def load(self, model_arch_file, checkpoint_file):\n",
 "        \"\"\"Rebuild a model from its JSON architecture and load its weights.\n",
 "\n",
 "        Returns:\n",
 "            The Keras model (not compiled here).\n",
 "        \"\"\"\n",
 "        print(f'Loading model architecture from {model_arch_file} ...')\n",
 "        with open(model_arch_file) as f:\n",
 "            model = model_from_json(f.read())\n",
 "        print(f'Loading model checkpoint from {checkpoint_file} ...')\n",
 "        model.load_weights(checkpoint_file)\n",
 "        print('Loaded the Model.')\n",
 "        return model"
 ] }, { "cell_type": "markdown", "metadata": { "id": "GMH8I9iB3uSu" }, "source": [ "## Generator" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bltRLOMevG9D" }, "outputs": [], "source": [
 "from tqdm import tqdm\n",
 "import numpy as np\n",
 "\n",
 "\n",
 "class LSTMChemGenerator(object):\n",
 "    \"\"\"Samples SMILES from an LSTMChem model one token at a time.\"\"\"\n",
 "\n",
 "    def __init__(self, modeler):\n",
 "        self.session = modeler.session\n",
 "        self.model = modeler.model\n",
 "        self.config = modeler.config\n",
 "        self.st = SmilesTokenizer()\n",
 "\n",
 "    def _generate(self, sequence):\n",
 "        \"\"\"Extend sequence until it ends in 'E' or exceeds the length limit.\n",
 "\n",
 "        Args:\n",
 "            sequence: seed string, normally the start token 'G'.\n",
 "\n",
 "        Returns:\n",
 "            The generated string without its first character and trailing 'E's.\n",
 "        \"\"\"\n",
 "        # Tokenize once per step and reuse it for both the check and the input.\n",
 "        tokens = self.st.tokenize(sequence)\n",
 "        while (sequence[-1] != 'E') and (len(tokens) <=\n",
 "                                         self.config.smiles_max_length):\n",
 "            x = self.st.one_hot_encode(tokens)\n",
 "            preds = self.model.predict_on_batch(x)[0][-1]\n",
 "            next_idx = self.sample_with_temp(preds)\n",
 "            sequence += self.st.table[next_idx]\n",
 "            tokens = self.st.tokenize(sequence)\n",
 "\n",
 "        sequence = sequence[1:].rstrip('E')\n",
 "\n",
 "        return sequence\n",
 "\n",
 "    def sample_with_temp(self, preds):\n",
 "        \"\"\"Draw a token index from preds after temperature scaling.\n",
 "\n",
 "        Args:\n",
 "            preds: probability vector over self.st.table.\n",
 "\n",
 "        Returns:\n",
 "            Sampled index.\n",
 "        \"\"\"\n",
 "        with np.errstate(divide='ignore'):\n",
 "            streched = np.log(np.asarray(preds, dtype=np.float64))\n",
 "        streched = streched / self.config.sampling_temp\n",
 "        # Subtract the max before exp so small temperatures cannot underflow\n",
 "        # every entry to 0 (which would make the probabilities NaN).\n",
 "        streched = streched - np.max(streched)\n",
 "        streched_probs = np.exp(streched)\n",
 "        streched_probs /= np.sum(streched_probs)\n",
 "        return np.random.choice(len(streched_probs), p=streched_probs)\n",
 "\n",
 "    def sample(self, num=1, start='G'):\n",
 "        \"\"\"Generate num SMILES strings.\n",
 "\n",
 "        In the 'generate' session raw samples are returned. Otherwise sampling\n",
 "        repeats until num samples parse with RDKit, and their canonical SMILES\n",
 "        are returned.\n",
 "        \"\"\"\n",
 "        sampled = []\n",
 "        if self.session == 'generate':\n",
 "            for _ in tqdm(range(num)):\n",
 "                sampled.append(self._generate(start))\n",
 "            return sampled\n",
 "        else:\n",
 "            from rdkit import Chem, RDLogger\n",
 "            RDLogger.DisableLog('rdApp.*')\n",
 "            while len(sampled) < num:\n",
 "                sequence = self._generate(start)\n",
 "                mol = Chem.MolFromSmiles(sequence)\n",
 "                if mol is not None:\n",
 "                    canon_smiles = Chem.MolToSmiles(mol)\n",
 "                    sampled.append(canon_smiles)\n",
 "            return sampled"
 ] }, { "cell_type": "markdown", "metadata": { "id": "PQuZ2lj-2y7g" }, "source": [ "## Fine Tuner" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XIL7Gx05vG55" }, "outputs": [], "source": [
 "class LSTMChemFinetuner(LSTMChemGenerator):\n",
 "    \"\"\"Fine-tunes a loaded LSTMChem model; inherits sampling from the generator.\"\"\"\n",
 "\n",
 "    def __init__(self, modeler, finetune_data_loader):\n",
 "        super().__init__(modeler)\n",
 "        self.finetune_data_loader = finetune_data_loader\n",
 "\n",
 "    def finetune(self):\n",
 "        \"\"\"Recompile the model and fit it on the fine-tuning loader.\n",
 "\n",
 "        Returns:\n",
 "            The Keras History returned by model.fit.\n",
 "        \"\"\"\n",
 "        self.model.compile(optimizer=self.config.optimizer,\n",
 "                           loss='categorical_crossentropy')\n",
 "\n",
 "        history = self.model.fit(\n",
 "            self.finetune_data_loader,\n",
 "            steps_per_epoch=len(self.finetune_data_loader),\n",
 "            epochs=self.config.finetune_epochs,\n",
 "            verbose=self.config.verbose_training,\n",
 "            use_multiprocessing=True,\n",
 "            shuffle=True)\n",
 "        return history"
 ] }, { "cell_type": "markdown", "metadata": { "id": "rPI7-OzS3bWd" }, "source": [ "## Trainer" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "bpmPDt1yvGqT" }, "outputs": [], "source": [
 "from glob import glob\n",
 "import os\n",
 "from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard\n",
 "\n",
 "\n",
 "class LSTMChemTrainer(object):\n",
 "    \"\"\"Trains an LSTMChem model with checkpointing and TensorBoard logging.\"\"\"\n",
 "\n",
 "    def __init__(self, modeler, train_data_loader, valid_data_loader):\n",
 "        self.model = modeler.model\n",
 "        self.config = modeler.config\n",
 "        self.train_data_loader = train_data_loader\n",
 "        self.valid_data_loader = valid_data_loader\n",
 "        self.callbacks = []\n",
 "        self.init_callbacks()\n",
 "\n",
 "    def init_callbacks(self):\n",
 "        \"\"\"Register the ModelCheckpoint and TensorBoard callbacks from config.\"\"\"\n",
 "        self.callbacks.append(\n",
 "            ModelCheckpoint(\n",
 "                filepath=os.path.join(\n",
 "                    self.config.checkpoint_dir,\n",
 "                    '%s-{epoch:02d}-{val_loss:.2f}.hdf5' %\n",
 "                    self.config.exp_name),\n",
 "                monitor=self.config.checkpoint_monitor,\n",
 "                mode=self.config.checkpoint_mode,\n",
 "                save_best_only=self.config.checkpoint_save_best_only,\n",
 "                save_weights_only=self.config.checkpoint_save_weights_only,\n",
 "                verbose=self.config.checkpoint_verbose,\n",
 "            ))\n",
 "        self.callbacks.append(\n",
 "            TensorBoard(\n",
 "                log_dir=self.config.tensorboard_log_dir,\n",
 "                write_graph=self.config.tensorboard_write_graph,\n",
 "            ))\n",
 "\n",
 "    def train(self):\n",
 "        \"\"\"Fit the model, record the final checkpoint and save the config.\n",
 "\n",
 "        Sets config.model_weight_filename and writes config.json into\n",
 "        config.exp_dir.\n",
 "\n",
 "        Returns:\n",
 "            The Keras History returned by model.fit.\n",
 "        \"\"\"\n",
 "        history = self.model.fit(\n",
 "            self.train_data_loader,\n",
 "            steps_per_epoch=len(self.train_data_loader),\n",
 "            epochs=self.config.num_epochs,\n",
 "            verbose=self.config.verbose_training,\n",
 "            validation_data=self.valid_data_loader,\n",
 "            validation_steps=len(self.valid_data_loader),\n",
 "            use_multiprocessing=True,\n",
 "            shuffle=True,\n",
 "            callbacks=self.callbacks)\n",
 "\n",
 "        self.config.model_weight_filename = self._find_last_weight_file()\n",
 "\n",
 "        with open(os.path.join(self.config.exp_dir, 'config.json'), 'w') as f:\n",
 "            f.write(self.config.toJSON(indent=2))\n",
 "        return history\n",
 "\n",
 "    def _find_last_weight_file(self):\n",
 "        \"\"\"Return the checkpoint of the final epoch, else the latest one saved.\n",
 "\n",
 "        With checkpoint_save_best_only the final epoch may not be written, so\n",
 "        fall back to the highest-epoch checkpoint instead of failing.\n",
 "\n",
 "        Raises:\n",
 "            FileNotFoundError: if no checkpoint for exp_name exists.\n",
 "        \"\"\"\n",
 "        prefix = f'{self.config.exp_name}-'\n",
 "        final = glob(\n",
 "            os.path.join(\n",
 "                f'{self.config.checkpoint_dir}',\n",
 "                f'{prefix}{self.config.num_epochs:02}*.hdf5'))\n",
 "        if final:\n",
 "            return final[0]\n",
 "\n",
 "        candidates = glob(\n",
 "            os.path.join(f'{self.config.checkpoint_dir}', f'{prefix}*.hdf5'))\n",
 "        if not candidates:\n",
 "            raise FileNotFoundError(\n",
 "                f'No checkpoint for {self.config.exp_name} in '\n",
 "                f'{self.config.checkpoint_dir}')\n",
 "\n",
 "        def epoch_of(path):\n",
 "            # Names follow '{exp_name}-{epoch:02d}-{val_loss:.2f}.hdf5'.\n",
 "            return int(os.path.basename(path)[len(prefix):].split('-', 1)[0])\n",
 "\n",
 "        return max(candidates, key=epoch_of)"
 ] }, { "cell_type": "markdown", "metadata": { "id": "oUg2HPWy3n8K" }, "source": [ "## Training" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 105 }, "id": "NVQQ-UBUvpZf", "outputId": "5595c9cc-a9d3-4c1d-c937-ad64e5aab838" }, "outputs": [], "source": [
 "from copy import copy\n",
 "\n",
 "# Set to True to train a new model from scratch (slow). The fine-tuning\n",
 "# section below loads an existing architecture and checkpoint instead.\n",
 "RUN_TRAINING = False\n",
 "CONFIG_FILE = '/content/base_config.json'\n",
 "\n",
 "\n",
 "def main():\n",
 "    \"\"\"Train LSTMChem from scratch using CONFIG_FILE.\"\"\"\n",
 "    config = process_config(CONFIG_FILE)\n",
 "\n",
 "    # create the experiments dirs\n",
 "    create_dirs(\n",
 "        [config.exp_dir, config.tensorboard_log_dir, config.checkpoint_dir])\n",
 "\n",
 "    print('Create the data generator.')\n",
 "    train_dl = DataLoader(config, data_type='train')\n",
 "    valid_dl = copy(train_dl)\n",
 "    valid_dl.data_type = 'valid'\n",
 "\n",
 "    print('Create the model.')\n",
 "    modeler = LSTMChem(config, session='train')\n",
 "\n",
 "    print('Create the trainer')\n",
 "    trainer = LSTMChemTrainer(modeler, train_dl, valid_dl)\n",
 "\n",
 "    print('Start training the model.')\n",
 "    trainer.train()\n",
 "\n",
 "\n",
 "if RUN_TRAINING:\n",
 "    main()"
 ] }, { "cell_type": "code", "source": [ " #!zip -r exp.zip experiments/" ], "metadata": { "id": "WYogVr1SYqIH" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "#from google.colab import files\n", "#files.download(\"/content/exp.zip\")" ], "metadata": { "id": "AD3Sb9cUaW4K" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "metadata": { "id": "u9o6Sxw14iN1" }, "source": [ "## Fine Tuning" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jiYkhYlk4wuc", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "2f4ee8a0-97ee-4e03-bc23-3cf6f610b067" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", "Requirement already satisfied: rdkit in /usr/local/lib/python3.8/dist-packages (2022.9.5)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.8/dist-packages (from rdkit) (1.22.4)\n", "Requirement already satisfied: Pillow in /usr/local/lib/python3.8/dist-packages (from rdkit) (8.4.0)\n" ] } ], "source": [
 "%pip install rdkit==2022.9.5"
 ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ESaAvgIV4gz8" }, "outputs": [], "source": [
 "import numpy as np\n",
 "from sklearn.decomposition import PCA\n",
 "import matplotlib.pyplot as plt\n",
 "%matplotlib inline\n",
 "\n",
 "from rdkit import Chem, DataStructs\n",
 "from rdkit.Chem import AllChem, Draw\n",
 "# Imported for its side effect: enables inline molecule rendering.\n",
 "from rdkit.Chem.Draw import IPythonConsole"
 ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RPVwkm9CvpV-", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "72b02711-0176-4bef-9ecb-10c3503b271b" }, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Loading model architecture from /content/drive/MyDrive/DDP FSR/experiments/2023-02-27/LSTM_Chem/model_arch.json ...\n", "Loading model checkpoint from /content/drive/MyDrive/DDP FSR/experiments/2023-02-27/LSTM_Chem/checkpoints/LSTM_Chem-04-0.44.hdf5 ...\n", "Loaded the Model.\n", "loading SMILES...\n", "done.\n", "tokenizing SMILES...\n" ] }, { "output_type": "stream", "name": "stderr", "text": [ "100%|██████████| 1567/1567 [00:03<00:00, 420.78it/s]\n" ] }, { "output_type": "stream", "name": "stdout", "text": [ "done.\n", "Epoch 1/15\n", "1567/1567 [==============================] - 30s 14ms/step - loss: 1.7070\n", "Epoch 2/15\n", "1567/1567 [==============================] - 18s 11ms/step - loss: 1.4273\n", "Epoch 3/15\n", "1567/1567 [==============================] - 17s 11ms/step - loss: 1.3239\n", "Epoch 4/15\n", "1567/1567 [==============================] - 18s 12ms/step - loss: 1.2488\n", "Epoch 5/15\n", "1567/1567 [==============================] - 19s 12ms/step - loss: 1.2020\n", "Epoch 6/15\n", "1567/1567 [==============================] - 18s 12ms/step - loss: 1.1559\n", "Epoch 7/15\n", "1567/1567 [==============================] - 18s 11ms/step - loss: 1.0998\n", "Epoch 8/15\n", "1567/1567 [==============================] - 19s 12ms/step - loss: 1.0606\n", "Epoch 9/15\n", "1567/1567 [==============================] - 18s 11ms/step - 
loss: 1.0335\n", "Epoch 10/15\n", "1567/1567 [==============================] - 18s 11ms/step - loss: 0.9963\n", "Epoch 11/15\n", "1567/1567 [==============================] - 19s 12ms/step - loss: 0.9668\n", "Epoch 12/15\n", "1567/1567 [==============================] - 18s 11ms/step - loss: 0.9427\n", "Epoch 13/15\n", "1567/1567 [==============================] - 18s 11ms/step - loss: 0.9174\n", "Epoch 14/15\n", "1567/1567 [==============================] - 21s 13ms/step - loss: 0.8950\n", "Epoch 15/15\n", "1567/1567 [==============================] - 18s 11ms/step - loss: 0.8713\n" ] }, { "output_type": "execute_result", "data": { "text/plain": [ "" ] }, "metadata": {}, "execution_count": 49 } ], "source": [ "config = process_config('/content/config.json')\n", "\n", "modeler = LSTMChem(config, session='finetune')\n", "finetune_dl = DataLoader(config, data_type='finetune')\n", "\n", "finetuner = LSTMChemFinetuner(modeler, finetune_dl)\n", "finetuner.finetune()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "CzkqEJsmvpRn" }, "outputs": [], "source": [ "finetuned_smiles = finetuner.sample(num=100)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NKkhyY_4-mS4", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "e2920866-79b5-4e26-89d5-2dfc696d3f27" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "['Cc1ccc2c(c1)C(C)C(CCCN(C)C)C1=C(C1)N2',\n", " 'CC(C)(C)C',\n", " 'CCCCCC',\n", " 'CN(C)CCON=C1c2ccccc2CCc2ccccc21',\n", " 'CC(C)(O)C(O)c1ccc(O)cc1',\n", " 'CCC',\n", " 'Cc1cc(O)ccc1C=C1[C]2CN(C)C2CC1=O',\n", " 'CCNc1cccc(C)c1C=Cc1ccccc1',\n", " 'CC(=O)c1cncn1CC1CCCCC1',\n", " 'CCCC(C)C1(C)C(=O)NC(=O)NC1=O',\n", " 'CCN(CC)C(=O)C1CCn2cnc1c2C(C)C',\n", " 'CC(C)C',\n", " 'C=C',\n", " 'NCCc1ccccc1',\n", " 'CNCCc1ccccn1',\n", " 'CN(C)CCON1c2ccccc2C=Cc2ccccc21',\n", " 'CN1C(=O)CC(c2ccccc2)C1=O',\n", " 'NC1=NC(=O)C1N(C(=O)C1=CC1)c1ccccc1',\n", " 
'Cc1ccc2c(c1)C(c1ccccc1)=NCC(=O)N2C',\n", " 'NC(=O)NO',\n", " 'NC(=O)Nc1cccc(O)c1O',\n", " 'CN1C(=O)CCC1Cc1ccccc1',\n", " 'Cc1ccc2[nH]c(=O)nc(-c3ccccc3)c2c1',\n", " 'CCN1c2ccccc2Sc2ccccc21',\n", " 'CC(=C1CCC(=O)NC1=O)c1ccccc1',\n", " 'COc1ccc2c(c1)N(CC(C)CN(C)C)c1ccccc1S2',\n", " 'CCOC(=O)Oc1ccccc1C(C)C',\n", " 'C=C',\n", " 'CCCC(=O)C(=O)n1c2ccccc3sc3ccc(C)cc21',\n", " 'NC1OCC(c2ccccc2)=NC1=O',\n", " 'CCC',\n", " 'CC1CC(CNC(=O)c2cc(S(C)(=O)=O)ccc2N)N1C',\n", " 'C=C',\n", " 'CCc1ccccc1',\n", " 'CC1CCCCN1',\n", " 'O=C(O)Oc1cc2ccccc2s1',\n", " 'CCC(O)CC1(CC)C(=O)NC(=O)NC1=O',\n", " 'Cc1ccc2c(c1)N(C)CCC(C)C2=O',\n", " 'CC(C)(C)OC(=O)C(C)(C)c1ccccc1',\n", " 'CN(CCc1ccccc1)CCn1c(=O)[nH]c2ccccc21',\n", " 'CC(=O)NC(C)NNC(O)c1ccccc1',\n", " 'NC1CCCCCN1c1ccccc1',\n", " 'CCCN1CCN(c2cccc(C)c2C)CC1',\n", " 'C=C(C)OC(=O)O',\n", " 'O=C(O)C1CCN(c2ccccc2)C1',\n", " 'CN(C)CCC(N)N1c2ccccc2Sc2ccccc21',\n", " 'C[N+](C)(C)CCO',\n", " 'CCN(CC)CCOC(=O)c1ccc(NC2CCCCC2)cc1',\n", " 'CCC',\n", " 'NC(N)=O',\n", " 'CC(NN)c1ccccc1',\n", " 'CN1CCN(CCCCOC2c3ccccc3N(C)c3ccccc32)CC1',\n", " 'COc1cccc(C(O)(O)OC(C)C)c1',\n", " 'CN(C)CCC=C1c2ccccc2C(C)(C)c2ccccc21',\n", " 'CC1OC(c2cccc(Br)c2)CN(C)CCN1C',\n", " 'CCN1N=NNC(=O)C(c2ccccc2)=C1SS',\n", " 'Cc1cccc(C2=NCC(=O)N(C)c3ccc(C)cc32)c1',\n", " 'Cc1ccc2c(c1)N(CC(C)CN(C)C)c1ccccc1S2',\n", " 'CCCCCC',\n", " 'COC(=O)C1(c2ccccc2)NNC(=O)NC1=O',\n", " 'CC(C)C',\n", " 'Cc1ccc(CCN(C)C)cc1O',\n", " 'CCN1CCCC1CNC(=O)C1=C(C)C(C(O)N(CC)C(=O)OC(C)(C)O)=C1',\n", " 'CCOC(CC)CC(N)=O',\n", " 'CNS(=O)(=O)c1ccc2c(c1)N(CCC1CCN(CCOC)CC1)c1ccccc1S2',\n", " 'CN(C)c1cnccn1',\n", " 'CCC',\n", " 'CN1CCCC([C]([C]2CCCC2)c2ccccc2)CC1',\n", " 'CCC(=O)c1ccc(C)cc1',\n", " 'C=C',\n", " 'CC(=O)C1(O)CCCCC1',\n", " 'CN(C)c1cnccn1',\n", " 'Cc1ccc2c(c1)N(CCCN(C)C)c1ccccc1S2',\n", " 'CN1C(=S)CN=C(c2ccccc2)c2ccccc21',\n", " 'CN(C)CCC(=O)CCn1c2c(ccccc3ccccc31)S2',\n", " 'CC(N)C(=O)NNC(Br)c1ccncc1',\n", " 'CN(C)CCCc1c2c(nc3ccc(O)cc13)N2C',\n", " 'CC(=O)Nc1ccc(C)cc1',\n", " 
'CC(C)C1CCC(=O)OC1(C)C',\n", " 'C=C',\n", " 'Cc1cccc(C)c1OCC1CCCN(C)C1',\n", " 'C=C',\n", " 'CCCCCCC',\n", " 'Cc1ccc2c(c1)C(c1ccccc1)=NCC(=O)N2C',\n", " 'Cc1ccc(C(=O)NCCN2CCN(c3ccccc3)CC2)cc1',\n", " 'CCn1ccc(=O)c(O)c1C',\n", " 'CCCCCCC',\n", " 'CCCCN1CCC(C(C)=O)CC1',\n", " 'CCN1C(=O)CN=C(CN2CCN(C)CC2)c2cc(Br)ccc21',\n", " 'CCOC(=O)c1ccc(O)cc1',\n", " 'CC(O)Cn1c2ccccc2n1C',\n", " 'CCC1(CC)C(=O)NC(=O)NC1=O',\n", " 'Cc1ccc2c(c1)C(c1ccccc1)=NCC(=S)N2CC(=O)O',\n", " 'Cc1ccc2c(c1)N(CC(C)CN(C)C)c1ccccc1S2',\n", " 'C1CCCCC1',\n", " 'NC(N)=Nc1nc(-c2ccccc2)cs1',\n", " 'Cc1cccc(C)c1NC1=NCC=C1',\n", " 'CN(C(=O)O)c1cc[nH]c(=O)c1-c1ccccc1',\n", " 'CN(C)CCc1ccccn1',\n", " 'Cc1cccc(C(=O)C(=O)N(C)C)c1']" ] }, "metadata": {}, "execution_count": 51 } ], "source": [ "finetuned_smiles\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "gn4vlKNZ5IYR" }, "outputs": [], "source": [ "with open('/content/bbbp.txt') as f:\n", " ksmiles = [l.rstrip() for l in f]\n", "kmols = [Chem.MolFromSmiles(smi) for smi in ksmiles]\n", "\n", "Kfps = []\n", "for mol in kmols:\n", " try:\n", " bv = AllChem.GetMACCSKeysFingerprint(mol)\n", " fp = np.zeros(len(bv))\n", " DataStructs.ConvertToNumpyArray(bv, fp)\n", " Kfps.append(fp)\n", " except:\n", " pass\n", "Klen = len(Kfps)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1WvcJGzt5SFF" }, "outputs": [], "source": [ "with open('/content/bbbp.txt') as f:\n", " fsmiles = [l.rstrip() for l in f]\n", "fmols = [Chem.MolFromSmiles(smi) for smi in fsmiles]\n", "\n", "Ffps, Fbvs = [], []\n", "for mol in fmols:\n", " try:\n", " bv = AllChem.GetMACCSKeysFingerprint(mol)\n", " Fbvs.append(bv)\n", " \n", " fp = np.zeros(len(bv))\n", " DataStructs.ConvertToNumpyArray(bv, fp)\n", " Ffps.append(fp)\n", " except:\n", " pass\n", "Flen = len(Ffps)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ZGRvrTSL5aDL" }, "outputs": [], "source": [ "Sfps, Sbvs, smols = [], [], []\n", "for smi in 
finetuned_smiles:\n",
  "    mol = Chem.MolFromSmiles(smi)\n",
  "    # Generated SMILES are not guaranteed to be valid: MolFromSmiles returns None for\n",
  "    # them and fingerprinting None raises. Skip such entries so smols, Sbvs and Sfps\n",
  "    # stay index-aligned (the nearest-neighbour lookup indexes smols by Sbvs position).\n",
  "    if mol is None:\n",
  "        continue\n",
  "    smols.append(mol)\n",
  "\n",
  "    bv = AllChem.GetMACCSKeysFingerprint(mol)\n",
  "    Sbvs.append(bv)\n",
  "\n",
  "    fp = np.zeros(len(bv))\n",
  "    DataStructs.ConvertToNumpyArray(bv, fp)\n",
  "    Sfps.append(fp)\n",
  "\n",
  "print(f'{len(Sfps)} of {len(finetuned_smiles)} generated SMILES parsed')"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "uRgKLu425dkb" }, "outputs": [], "source": [
  "# Fit one 2-D PCA on all fingerprints, stacked as known + reference + generated;\n",
  "# the plot below slices X back apart using Klen and Flen.\n",
  "x = Kfps + Ffps + Sfps\n",
  "pca = PCA(n_components=2, random_state=71)\n",
  "X = pca.fit_transform(x)"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "726JRpW35e51", "colab": { "base_uri": "https://localhost:8080/", "height": 334 }, "outputId": "203eb081-ca40-499b-edb3-2f048b4f33ff" }, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "\n" }, "metadata": { "needs_background": "light" } } ], "source": [ "plt.figure(figsize=(8, 5))\n", "plt.scatter(X[:Klen, 0], X[:Klen, 1],\n", " c='w', edgecolors='k', label='Known E. Coli inhibitors')\n", "#plt.scatter(X[Klen:Klen + Flen, 0], X[Klen:Klen + Flen, 1],\n", "# s=200, c='r', marker='2', edgecolors='k', label='Iteration 4 Molecules')\n", "plt.scatter(X[Klen + Flen:, 0], X[Klen + Flen:, 1],\n", " c='b', marker='+', label='Finetuned Generated')\n", "plt.xlabel('PC 1')\n", "plt.ylabel('PC 2')\n", "plt.legend();" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UhWJHYsJ5jC-" }, "outputs": [], "source": [ "idxs = []\n", "for Fbv in Fbvs:\n", " idx = np.argmax(DataStructs.BulkTanimotoSimilarity(Fbv, Sbvs))\n", " idxs.append(idx)\n", "nsmols = [smols[idx] for idx in idxs]" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "RDIXRY0o9yOd" }, "outputs": [], "source": [ "showmols = []\n", "for i, j in zip(fmols, nsmols):\n", " showmols.append(i)\n", " showmols.append(j)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "-FCC59yc1fiP", "colab": { "base_uri": "https://localhost:8080/", "height": 1000 }, "outputId": "2a7697fb-a2d8-4f20-cde2-6387c1d939bd" }, "outputs": [ { "output_type": "stream", "name": "stderr", "text": [ "/usr/local/lib/python3.8/dist-packages/rdkit/Chem/Draw/IPythonConsole.py:258: UserWarning: Truncating the list of molecules to be displayed to 50. 
Change the maxMols value to display more.\n", " warnings.warn(\n" ] }, { "output_type": "execute_result", "data": { "image/png": "\n", "text/plain": [ "" ] }, "metadata": {}, "execution_count": 67 } ], "source": [
  "# RDKit's IPythonConsole truncates grids to 50 molecules (maxMols). Slice explicitly so\n",
  "# the cap is visible and tunable: the first N_PAIRS (reference, generated) pairs are shown.\n",
  "N_PAIRS = 25\n",
  "Draw.MolsToGridImage(showmols[:2 * N_PAIRS], molsPerRow=2)"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "aMz-y3NWoYfc" }, "outputs": [], "source": [
  "# Save the generated SMILES, one per line.\n",
  "with open('bbbp_gen_fsr.txt', 'w') as f:\n",
  "    f.writelines(f\"{smi}\\n\" for smi in finetuned_smiles)"
] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "tclLlv8ot5zX", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "ec2dea5d-5c81-435e-998f-3ff0ec70b1f1" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ "100" ] }, "metadata": {}, "execution_count": 69 } ], "source": [ "len(finetuned_smiles)" ] } ], "metadata": { "accelerator": "GPU", "colab": { "provenance": [] }, "gpuClass": "standard", "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }