{ "cells": [ { "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ "# materials.smi-TED - INFERENCE (Regression)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Install extra packages for notebook\n", "%pip install seaborn xgboost" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import sys\n", "sys.path.append('../inference')" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "# materials.smi-ted (smi-ted)\n", "from smi_ted_light.load import load_smi_ted\n", "\n", "# Data\n", "import torch\n", "import pandas as pd\n", "import numpy as np\n", "\n", "# Chemistry\n", "from rdkit import Chem\n", "from rdkit.Chem import PandasTools\n", "from rdkit.Chem import Descriptors\n", "PandasTools.RenderImagesInAllDataFrames(True)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "# function to canonicalize SMILES\n", "def normalize_smiles(smi, canonical=True, isomeric=False):\n", "    \"\"\"Return the RDKit SMILES for `smi`, or None if it cannot be processed.\n", "\n", "    Args:\n", "        smi: Input SMILES string.\n", "        canonical: Passed to Chem.MolToSmiles as `canonical`.\n", "        isomeric: Passed to Chem.MolToSmiles as `isomericSmiles`.\n", "\n", "    Returns:\n", "        The SMILES string written by RDKit, or None when parsing or writing fails.\n", "    \"\"\"\n", "    try:\n", "        mol = Chem.MolFromSmiles(smi)\n", "        # RDKit returns None (no exception) for a SMILES it cannot parse\n", "        if mol is None:\n", "            return None\n", "        return Chem.MolToSmiles(mol, canonical=canonical, isomericSmiles=isomeric)\n", "    except Exception:\n", "        # best-effort for bad inputs (e.g. non-string values); unlike a bare\n", "        # `except:` this does not swallow KeyboardInterrupt / SystemExit\n", "        return None" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Import smi-ted" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Random Seed: 12345\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Using Rotation Embedding\n", "Vocab size: 2393\n", "[INFERENCE MODE - smi-ted-Light]\n" ] } ], "source": [ "model_smi_ted = 
load_smi_ted(\n", " folder='../inference/smi_ted_light',\n", " ckpt_filename='smi-ted-Light_40.pt'\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Lipophilicity Dataset" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Experiments - Data Load" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "df_train = pd.read_csv(\"../finetune/moleculenet/lipophilicity/train.csv\")\n", "df_test = pd.read_csv(\"../finetune/moleculenet/lipophilicity/test.csv\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### SMILES canonicalization" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(3360, 3)\n" ] }, { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>smiles</th>\n", " <th>y</th>\n", " <th>norm_smiles</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>Nc1ncnc2c1c(COc3cccc(Cl)c3)nn2C4CCOCC4</td>\n", " <td>0.814313</td>\n", " <td>Nc1ncnc2c1c(COc1cccc(Cl)c1)nn2C1CCOCC1</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>COc1cc(cc2cnc(Nc3ccc(cc3)[C@@H](C)NC(=O)C)nc12...</td>\n", " <td>0.446346</td>\n", " <td>COc1cc(-c2ccncc2)cc2cnc(Nc3ccc(C(C)NC(C)=O)cc3...</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>CC(=O)Nc1ccc2ccn(c3cc(Nc4ccn(C)n4)n5ncc(C#N)c5...</td>\n", " <td>1.148828</td>\n", " <td>CC(=O)Nc1ccc2ccn(-c3cc(Nc4ccn(C)n4)n4ncc(C#N)c...</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>Oc1ccc(CCNCCS(=O)(=O)CCCOCCSc2ccccc2)c3sc(O)nc13</td>\n", " <td>0.404532</td>\n", " 
<td>O=S(=O)(CCCOCCSc1ccccc1)CCNCCc1ccc(O)c2nc(O)sc12</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>Clc1ccc2C(=O)C3=C(Nc2c1)C(=O)NN(Cc4cc5ccccc5s4...</td>\n", " <td>-0.164144</td>\n", " <td>O=c1[nH]n(Cc2cc3ccccc3s2)c(=O)c2c(=O)c3ccc(Cl)...</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " smiles y \\\n", "0 Nc1ncnc2c1c(COc3cccc(Cl)c3)nn2C4CCOCC4 0.814313 \n", "1 COc1cc(cc2cnc(Nc3ccc(cc3)[C@@H](C)NC(=O)C)nc12... 0.446346 \n", "2 CC(=O)Nc1ccc2ccn(c3cc(Nc4ccn(C)n4)n5ncc(C#N)c5... 1.148828 \n", "3 Oc1ccc(CCNCCS(=O)(=O)CCCOCCSc2ccccc2)c3sc(O)nc13 0.404532 \n", "4 Clc1ccc2C(=O)C3=C(Nc2c1)C(=O)NN(Cc4cc5ccccc5s4... -0.164144 \n", "\n", " norm_smiles \n", "0 Nc1ncnc2c1c(COc1cccc(Cl)c1)nn2C1CCOCC1 \n", "1 COc1cc(-c2ccncc2)cc2cnc(Nc3ccc(C(C)NC(C)=O)cc3... \n", "2 CC(=O)Nc1ccc2ccn(-c3cc(Nc4ccn(C)n4)n4ncc(C#N)c... \n", "3 O=S(=O)(CCCOCCSc1ccccc1)CCNCCc1ccc(O)c2nc(O)sc12 \n", "4 O=c1[nH]n(Cc2cc3ccccc3s2)c(=O)c2c(=O)c3ccc(Cl)... " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_train['norm_smiles'] = df_train['smiles'].apply(normalize_smiles)\n", "df_train_normalized = df_train.dropna()\n", "print(df_train_normalized.shape)\n", "df_train_normalized.head()" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(420, 3)\n" ] }, { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>smiles</th>\n", " <th>y</th>\n", " <th>norm_smiles</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>N(c1ccccc1)c2ccnc3ccccc23</td>\n", " 
<td>0.488161</td>\n", " <td>c1ccc(Nc2ccnc3ccccc23)cc1</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>Clc1ccc2Oc3ccccc3N=C(N4CCNCC4)c2c1</td>\n", " <td>0.070017</td>\n", " <td>Clc1ccc2c(c1)C(N1CCNCC1)=Nc1ccccc1O2</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>NC1(CCC1)c2ccc(cc2)c3ncc4cccnc4c3c5ccccc5</td>\n", " <td>-0.415030</td>\n", " <td>NC1(c2ccc(-c3ncc4cccnc4c3-c3ccccc3)cc2)CCC1</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>OC[C@H](O)CN1C(=O)[C@@H](Cc2ccccc12)NC(=O)c3cc...</td>\n", " <td>0.897942</td>\n", " <td>O=C(NC1Cc2ccccc2N(CC(O)CO)C1=O)c1cc2cc(Cl)sc2[...</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>NS(=O)(=O)c1nc2ccccc2s1</td>\n", " <td>-0.707731</td>\n", " <td>NS(=O)(=O)c1nc2ccccc2s1</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "</div>" ], "text/plain": [ " smiles y \\\n", "0 N(c1ccccc1)c2ccnc3ccccc23 0.488161 \n", "1 Clc1ccc2Oc3ccccc3N=C(N4CCNCC4)c2c1 0.070017 \n", "2 NC1(CCC1)c2ccc(cc2)c3ncc4cccnc4c3c5ccccc5 -0.415030 \n", "3 OC[C@H](O)CN1C(=O)[C@@H](Cc2ccccc12)NC(=O)c3cc... 0.897942 \n", "4 NS(=O)(=O)c1nc2ccccc2s1 -0.707731 \n", "\n", " norm_smiles \n", "0 c1ccc(Nc2ccnc3ccccc23)cc1 \n", "1 Clc1ccc2c(c1)C(N1CCNCC1)=Nc1ccccc1O2 \n", "2 NC1(c2ccc(-c3ncc4cccnc4c3-c3ccccc3)cc2)CCC1 \n", "3 O=C(NC1Cc2ccccc2N(CC(O)CO)C1=O)c1cc2cc(Cl)sc2[... 
\n", "4 NS(=O)(=O)c1nc2ccccc2s1 " ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_test['norm_smiles'] = df_test['smiles'].apply(normalize_smiles)\n", "df_test_normalized = df_test.dropna()\n", "print(df_test_normalized.shape)\n", "df_test_normalized.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Embeddings extraction " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### smi-ted embeddings extraction" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 33/33 [00:38<00:00, 1.15s/it]\n" ] }, { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>0</th>\n", " <th>1</th>\n", " <th>2</th>\n", " <th>3</th>\n", " <th>4</th>\n", " <th>5</th>\n", " <th>6</th>\n", " <th>7</th>\n", " <th>8</th>\n", " <th>9</th>\n", " <th>...</th>\n", " <th>758</th>\n", " <th>759</th>\n", " <th>760</th>\n", " <th>761</th>\n", " <th>762</th>\n", " <th>763</th>\n", " <th>764</th>\n", " <th>765</th>\n", " <th>766</th>\n", " <th>767</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>0.367646</td>\n", " <td>-0.504889</td>\n", " <td>0.040485</td>\n", " <td>0.385314</td>\n", " <td>0.564923</td>\n", " <td>-0.684497</td>\n", " <td>1.160397</td>\n", " <td>0.071218</td>\n", " <td>0.799428</td>\n", " <td>0.181323</td>\n", " <td>...</td>\n", " <td>-1.379994</td>\n", " <td>-0.167221</td>\n", " <td>0.104886</td>\n", " <td>0.239571</td>\n", " <td>-0.744390</td>\n", " <td>0.590423</td>\n", " <td>-0.808946</td>\n", " <td>0.792584</td>\n", " 
<td>0.550898</td>\n", " <td>-0.176831</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>0.455316</td>\n", " <td>-0.485554</td>\n", " <td>0.062206</td>\n", " <td>0.387994</td>\n", " <td>0.567590</td>\n", " <td>-0.713285</td>\n", " <td>1.144267</td>\n", " <td>-0.057046</td>\n", " <td>0.753016</td>\n", " <td>0.112180</td>\n", " <td>...</td>\n", " <td>-1.332142</td>\n", " <td>-0.096662</td>\n", " <td>0.221944</td>\n", " <td>0.327923</td>\n", " <td>-0.739358</td>\n", " <td>0.659803</td>\n", " <td>-0.775723</td>\n", " <td>0.745837</td>\n", " <td>0.566330</td>\n", " <td>-0.111946</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>0.442309</td>\n", " <td>-0.484732</td>\n", " <td>0.084945</td>\n", " <td>0.384787</td>\n", " <td>0.564752</td>\n", " <td>-0.704130</td>\n", " <td>1.159491</td>\n", " <td>0.021168</td>\n", " <td>0.846539</td>\n", " <td>0.118463</td>\n", " <td>...</td>\n", " <td>-1.324177</td>\n", " <td>-0.110403</td>\n", " <td>0.207824</td>\n", " <td>0.281665</td>\n", " <td>-0.780818</td>\n", " <td>0.693484</td>\n", " <td>-0.832626</td>\n", " <td>0.763095</td>\n", " <td>0.532460</td>\n", " <td>-0.196708</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>0.527961</td>\n", " <td>-0.519151</td>\n", " <td>0.091635</td>\n", " <td>0.353518</td>\n", " <td>0.421795</td>\n", " <td>-0.724220</td>\n", " <td>1.093752</td>\n", " <td>0.148574</td>\n", " <td>0.804047</td>\n", " <td>0.194627</td>\n", " <td>...</td>\n", " <td>-1.358414</td>\n", " <td>-0.111483</td>\n", " <td>0.151692</td>\n", " <td>0.186741</td>\n", " <td>-0.601867</td>\n", " <td>0.641591</td>\n", " <td>-0.747422</td>\n", " <td>0.794239</td>\n", " <td>0.640765</td>\n", " <td>-0.239649</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>0.464432</td>\n", " <td>-0.511090</td>\n", " <td>0.038785</td>\n", " <td>0.346217</td>\n", " <td>0.492919</td>\n", " <td>-0.619387</td>\n", " <td>1.048157</td>\n", " <td>0.095910</td>\n", " <td>0.738604</td>\n", " <td>0.119270</td>\n", " <td>...</td>\n", " 
<td>-1.223927</td>\n", " <td>-0.109863</td>\n", " <td>0.151280</td>\n", " <td>0.244834</td>\n", " <td>-0.686610</td>\n", " <td>0.759327</td>\n", " <td>-0.756338</td>\n", " <td>0.766427</td>\n", " <td>0.610454</td>\n", " <td>-0.197345</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>5 rows × 768 columns</p>\n", "</div>" ], "text/plain": [ " 0 1 2 3 4 5 6 \\\n", "0 0.367646 -0.504889 0.040485 0.385314 0.564923 -0.684497 1.160397 \n", "1 0.455316 -0.485554 0.062206 0.387994 0.567590 -0.713285 1.144267 \n", "2 0.442309 -0.484732 0.084945 0.384787 0.564752 -0.704130 1.159491 \n", "3 0.527961 -0.519151 0.091635 0.353518 0.421795 -0.724220 1.093752 \n", "4 0.464432 -0.511090 0.038785 0.346217 0.492919 -0.619387 1.048157 \n", "\n", " 7 8 9 ... 758 759 760 761 \\\n", "0 0.071218 0.799428 0.181323 ... -1.379994 -0.167221 0.104886 0.239571 \n", "1 -0.057046 0.753016 0.112180 ... -1.332142 -0.096662 0.221944 0.327923 \n", "2 0.021168 0.846539 0.118463 ... -1.324177 -0.110403 0.207824 0.281665 \n", "3 0.148574 0.804047 0.194627 ... -1.358414 -0.111483 0.151692 0.186741 \n", "4 0.095910 0.738604 0.119270 ... 
-1.223927 -0.109863 0.151280 0.244834 \n", "\n", " 762 763 764 765 766 767 \n", "0 -0.744390 0.590423 -0.808946 0.792584 0.550898 -0.176831 \n", "1 -0.739358 0.659803 -0.775723 0.745837 0.566330 -0.111946 \n", "2 -0.780818 0.693484 -0.832626 0.763095 0.532460 -0.196708 \n", "3 -0.601867 0.641591 -0.747422 0.794239 0.640765 -0.239649 \n", "4 -0.686610 0.759327 -0.756338 0.766427 0.610454 -0.197345 \n", "\n", "[5 rows x 768 columns]" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "with torch.no_grad():\n", " df_embeddings_train = model_smi_ted.encode(df_train_normalized['norm_smiles'])\n", "df_embeddings_train.head()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "100%|██████████| 4/4 [00:05<00:00, 1.46s/it]\n" ] }, { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>0</th>\n", " <th>1</th>\n", " <th>2</th>\n", " <th>3</th>\n", " <th>4</th>\n", " <th>5</th>\n", " <th>6</th>\n", " <th>7</th>\n", " <th>8</th>\n", " <th>9</th>\n", " <th>...</th>\n", " <th>758</th>\n", " <th>759</th>\n", " <th>760</th>\n", " <th>761</th>\n", " <th>762</th>\n", " <th>763</th>\n", " <th>764</th>\n", " <th>765</th>\n", " <th>766</th>\n", " <th>767</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>0.392252</td>\n", " <td>-0.504846</td>\n", " <td>0.056791</td>\n", " <td>0.356297</td>\n", " <td>0.475918</td>\n", " <td>-0.648899</td>\n", " <td>1.157862</td>\n", " <td>-0.022914</td>\n", " <td>0.703240</td>\n", " <td>0.192023</td>\n", " <td>...</td>\n", " 
<td>-1.208714</td>\n", " <td>-0.094441</td>\n", " <td>0.128845</td>\n", " <td>0.403995</td>\n", " <td>-0.782782</td>\n", " <td>0.541907</td>\n", " <td>-0.707272</td>\n", " <td>0.901041</td>\n", " <td>0.629461</td>\n", " <td>-0.020630</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>0.387422</td>\n", " <td>-0.481142</td>\n", " <td>0.049675</td>\n", " <td>0.353058</td>\n", " <td>0.601170</td>\n", " <td>-0.646099</td>\n", " <td>1.142392</td>\n", " <td>0.060092</td>\n", " <td>0.763799</td>\n", " <td>0.110331</td>\n", " <td>...</td>\n", " <td>-1.248282</td>\n", " <td>-0.139790</td>\n", " <td>0.075585</td>\n", " <td>0.202242</td>\n", " <td>-0.729794</td>\n", " <td>0.705914</td>\n", " <td>-0.771751</td>\n", " <td>0.843173</td>\n", " <td>0.618850</td>\n", " <td>-0.213584</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>0.390975</td>\n", " <td>-0.510056</td>\n", " <td>0.070656</td>\n", " <td>0.380695</td>\n", " <td>0.601486</td>\n", " <td>-0.595827</td>\n", " <td>1.182193</td>\n", " <td>0.011085</td>\n", " <td>0.688093</td>\n", " <td>0.056453</td>\n", " <td>...</td>\n", " <td>-1.294595</td>\n", " <td>-0.164846</td>\n", " <td>0.194435</td>\n", " <td>0.240742</td>\n", " <td>-0.773443</td>\n", " <td>0.608631</td>\n", " <td>-0.747181</td>\n", " <td>0.791911</td>\n", " <td>0.611874</td>\n", " <td>-0.125455</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>0.423924</td>\n", " <td>-0.557325</td>\n", " <td>0.083810</td>\n", " <td>0.328703</td>\n", " <td>0.399589</td>\n", " <td>-0.622818</td>\n", " <td>1.079945</td>\n", " <td>0.097611</td>\n", " <td>0.724030</td>\n", " <td>0.135976</td>\n", " <td>...</td>\n", " <td>-1.412060</td>\n", " <td>-0.106541</td>\n", " <td>0.153314</td>\n", " <td>0.209962</td>\n", " <td>-0.699690</td>\n", " <td>0.648061</td>\n", " <td>-0.716241</td>\n", " <td>0.757986</td>\n", " <td>0.615963</td>\n", " <td>-0.258693</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>0.335576</td>\n", " <td>-0.559591</td>\n", " <td>0.119437</td>\n", " 
<td>0.364141</td>\n", " <td>0.375474</td>\n", " <td>-0.639833</td>\n", " <td>1.144707</td>\n", " <td>0.077512</td>\n", " <td>0.791759</td>\n", " <td>0.164201</td>\n", " <td>...</td>\n", " <td>-1.279041</td>\n", " <td>-0.186733</td>\n", " <td>0.106963</td>\n", " <td>0.254949</td>\n", " <td>-0.651694</td>\n", " <td>0.594167</td>\n", " <td>-0.680426</td>\n", " <td>0.887482</td>\n", " <td>0.651587</td>\n", " <td>-0.144996</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>5 rows × 768 columns</p>\n", "</div>" ], "text/plain": [ " 0 1 2 3 4 5 6 \\\n", "0 0.392252 -0.504846 0.056791 0.356297 0.475918 -0.648899 1.157862 \n", "1 0.387422 -0.481142 0.049675 0.353058 0.601170 -0.646099 1.142392 \n", "2 0.390975 -0.510056 0.070656 0.380695 0.601486 -0.595827 1.182193 \n", "3 0.423924 -0.557325 0.083810 0.328703 0.399589 -0.622818 1.079945 \n", "4 0.335576 -0.559591 0.119437 0.364141 0.375474 -0.639833 1.144707 \n", "\n", " 7 8 9 ... 758 759 760 761 \\\n", "0 -0.022914 0.703240 0.192023 ... -1.208714 -0.094441 0.128845 0.403995 \n", "1 0.060092 0.763799 0.110331 ... -1.248282 -0.139790 0.075585 0.202242 \n", "2 0.011085 0.688093 0.056453 ... -1.294595 -0.164846 0.194435 0.240742 \n", "3 0.097611 0.724030 0.135976 ... -1.412060 -0.106541 0.153314 0.209962 \n", "4 0.077512 0.791759 0.164201 ... 
-1.279041 -0.186733 0.106963 0.254949 \n", "\n", " 762 763 764 765 766 767 \n", "0 -0.782782 0.541907 -0.707272 0.901041 0.629461 -0.020630 \n", "1 -0.729794 0.705914 -0.771751 0.843173 0.618850 -0.213584 \n", "2 -0.773443 0.608631 -0.747181 0.791911 0.611874 -0.125455 \n", "3 -0.699690 0.648061 -0.716241 0.757986 0.615963 -0.258693 \n", "4 -0.651694 0.594167 -0.680426 0.887482 0.651587 -0.144996 \n", "\n", "[5 rows x 768 columns]" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "with torch.no_grad():\n", " df_embeddings_test = model_smi_ted.encode(df_test_normalized['norm_smiles'])\n", "df_embeddings_test.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Experiments - Lipophilicity prediction using smi-ted latent spaces" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### XGBoost prediction using the whole Latent Space" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [], "source": [ "from xgboost import XGBRegressor\n", "from sklearn.metrics import mean_squared_error" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/html": [ "<style>#sk-container-id-1 {\n", " /* Definition of color scheme common for light and dark mode */\n", " --sklearn-color-text: black;\n", " --sklearn-color-line: gray;\n", " /* Definition of color scheme for unfitted estimators */\n", " --sklearn-color-unfitted-level-0: #fff5e6;\n", " --sklearn-color-unfitted-level-1: #f6e4d2;\n", " --sklearn-color-unfitted-level-2: #ffe0b3;\n", " --sklearn-color-unfitted-level-3: chocolate;\n", " /* Definition of color scheme for fitted estimators */\n", " --sklearn-color-fitted-level-0: #f0f8ff;\n", " --sklearn-color-fitted-level-1: #d4ebff;\n", " --sklearn-color-fitted-level-2: #b3dbfd;\n", " --sklearn-color-fitted-level-3: cornflowerblue;\n", "\n", " /* Specific color for light theme */\n", " --sklearn-color-text-on-default-background: 
var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n", " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n", " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n", " --sklearn-color-icon: #696969;\n", "\n", " @media (prefers-color-scheme: dark) {\n", " /* Redefinition of color scheme for dark theme */\n", " --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n", " --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n", " --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n", " --sklearn-color-icon: #878787;\n", " }\n", "}\n", "\n", "#sk-container-id-1 {\n", " color: var(--sklearn-color-text);\n", "}\n", "\n", "#sk-container-id-1 pre {\n", " padding: 0;\n", "}\n", "\n", "#sk-container-id-1 input.sk-hidden--visually {\n", " border: 0;\n", " clip: rect(1px 1px 1px 1px);\n", " clip: rect(1px, 1px, 1px, 1px);\n", " height: 1px;\n", " margin: -1px;\n", " overflow: hidden;\n", " padding: 0;\n", " position: absolute;\n", " width: 1px;\n", "}\n", "\n", "#sk-container-id-1 div.sk-dashed-wrapped {\n", " border: 1px dashed var(--sklearn-color-line);\n", " margin: 0 0.4em 0.5em 0.4em;\n", " box-sizing: border-box;\n", " padding-bottom: 0.4em;\n", " background-color: var(--sklearn-color-background);\n", "}\n", "\n", "#sk-container-id-1 div.sk-container {\n", " /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n", " but bootstrap.min.css set `[hidden] { display: none !important; }`\n", " so we also need the `!important` here to be able to override the\n", " default hidden behavior on the sphinx rendered scikit-learn.org.\n", " See: https://github.com/scikit-learn/scikit-learn/issues/21755 
*/\n", " display: inline-block !important;\n", " position: relative;\n", "}\n", "\n", "#sk-container-id-1 div.sk-text-repr-fallback {\n", " display: none;\n", "}\n", "\n", "div.sk-parallel-item,\n", "div.sk-serial,\n", "div.sk-item {\n", " /* draw centered vertical line to link estimators */\n", " background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n", " background-size: 2px 100%;\n", " background-repeat: no-repeat;\n", " background-position: center center;\n", "}\n", "\n", "/* Parallel-specific style estimator block */\n", "\n", "#sk-container-id-1 div.sk-parallel-item::after {\n", " content: \"\";\n", " width: 100%;\n", " border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n", " flex-grow: 1;\n", "}\n", "\n", "#sk-container-id-1 div.sk-parallel {\n", " display: flex;\n", " align-items: stretch;\n", " justify-content: center;\n", " background-color: var(--sklearn-color-background);\n", " position: relative;\n", "}\n", "\n", "#sk-container-id-1 div.sk-parallel-item {\n", " display: flex;\n", " flex-direction: column;\n", "}\n", "\n", "#sk-container-id-1 div.sk-parallel-item:first-child::after {\n", " align-self: flex-end;\n", " width: 50%;\n", "}\n", "\n", "#sk-container-id-1 div.sk-parallel-item:last-child::after {\n", " align-self: flex-start;\n", " width: 50%;\n", "}\n", "\n", "#sk-container-id-1 div.sk-parallel-item:only-child::after {\n", " width: 0;\n", "}\n", "\n", "/* Serial-specific style estimator block */\n", "\n", "#sk-container-id-1 div.sk-serial {\n", " display: flex;\n", " flex-direction: column;\n", " align-items: center;\n", " background-color: var(--sklearn-color-background);\n", " padding-right: 1em;\n", " padding-left: 1em;\n", "}\n", "\n", "\n", "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n", "clickable and can be expanded/collapsed.\n", "- Pipeline and ColumnTransformer use this feature and define the 
default style\n", "- Estimators will overwrite some part of the style using the `sk-estimator` class\n", "*/\n", "\n", "/* Pipeline and ColumnTransformer style (default) */\n", "\n", "#sk-container-id-1 div.sk-toggleable {\n", " /* Default theme specific background. It is overwritten whether we have a\n", " specific estimator or a Pipeline/ColumnTransformer */\n", " background-color: var(--sklearn-color-background);\n", "}\n", "\n", "/* Toggleable label */\n", "#sk-container-id-1 label.sk-toggleable__label {\n", " cursor: pointer;\n", " display: block;\n", " width: 100%;\n", " margin-bottom: 0;\n", " padding: 0.5em;\n", " box-sizing: border-box;\n", " text-align: center;\n", "}\n", "\n", "#sk-container-id-1 label.sk-toggleable__label-arrow:before {\n", " /* Arrow on the left of the label */\n", " content: \"▸\";\n", " float: left;\n", " margin-right: 0.25em;\n", " color: var(--sklearn-color-icon);\n", "}\n", "\n", "#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {\n", " color: var(--sklearn-color-text);\n", "}\n", "\n", "/* Toggleable content - dropdown */\n", "\n", "#sk-container-id-1 div.sk-toggleable__content {\n", " max-height: 0;\n", " max-width: 0;\n", " overflow: hidden;\n", " text-align: left;\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-0);\n", "}\n", "\n", "#sk-container-id-1 div.sk-toggleable__content.fitted {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-0);\n", "}\n", "\n", "#sk-container-id-1 div.sk-toggleable__content pre {\n", " margin: 0.2em;\n", " border-radius: 0.25em;\n", " color: var(--sklearn-color-text);\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-0);\n", "}\n", "\n", "#sk-container-id-1 div.sk-toggleable__content.fitted pre {\n", " /* unfitted */\n", " background-color: var(--sklearn-color-fitted-level-0);\n", "}\n", "\n", "#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n", " /* Expand 
drop-down */\n", " max-height: 200px;\n", " max-width: 100%;\n", " overflow: auto;\n", "}\n", "\n", "#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n", " content: \"▾\";\n", "}\n", "\n", "/* Pipeline/ColumnTransformer-specific style */\n", "\n", "#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", " color: var(--sklearn-color-text);\n", " background-color: var(--sklearn-color-unfitted-level-2);\n", "}\n", "\n", "#sk-container-id-1 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", " background-color: var(--sklearn-color-fitted-level-2);\n", "}\n", "\n", "/* Estimator-specific style */\n", "\n", "/* Colorize estimator box */\n", "#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-2);\n", "}\n", "\n", "#sk-container-id-1 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-2);\n", "}\n", "\n", "#sk-container-id-1 div.sk-label label.sk-toggleable__label,\n", "#sk-container-id-1 div.sk-label label {\n", " /* The background is the default theme color */\n", " color: var(--sklearn-color-text-on-default-background);\n", "}\n", "\n", "/* On hover, darken the color of the background */\n", "#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {\n", " color: var(--sklearn-color-text);\n", " background-color: var(--sklearn-color-unfitted-level-2);\n", "}\n", "\n", "/* Label box, darken color on hover, fitted */\n", "#sk-container-id-1 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n", " color: var(--sklearn-color-text);\n", " background-color: var(--sklearn-color-fitted-level-2);\n", "}\n", "\n", "/* Estimator label */\n", "\n", "#sk-container-id-1 div.sk-label label {\n", 
" font-family: monospace;\n", " font-weight: bold;\n", " display: inline-block;\n", " line-height: 1.2em;\n", "}\n", "\n", "#sk-container-id-1 div.sk-label-container {\n", " text-align: center;\n", "}\n", "\n", "/* Estimator-specific */\n", "#sk-container-id-1 div.sk-estimator {\n", " font-family: monospace;\n", " border: 1px dotted var(--sklearn-color-border-box);\n", " border-radius: 0.25em;\n", " box-sizing: border-box;\n", " margin-bottom: 0.5em;\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-0);\n", "}\n", "\n", "#sk-container-id-1 div.sk-estimator.fitted {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-0);\n", "}\n", "\n", "/* on hover */\n", "#sk-container-id-1 div.sk-estimator:hover {\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-2);\n", "}\n", "\n", "#sk-container-id-1 div.sk-estimator.fitted:hover {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-2);\n", "}\n", "\n", "/* Specification for estimator info (e.g. 
\"i\" and \"?\") */\n", "\n", "/* Common style for \"i\" and \"?\" */\n", "\n", ".sk-estimator-doc-link,\n", "a:link.sk-estimator-doc-link,\n", "a:visited.sk-estimator-doc-link {\n", " float: right;\n", " font-size: smaller;\n", " line-height: 1em;\n", " font-family: monospace;\n", " background-color: var(--sklearn-color-background);\n", " border-radius: 1em;\n", " height: 1em;\n", " width: 1em;\n", " text-decoration: none !important;\n", " margin-left: 1ex;\n", " /* unfitted */\n", " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", " color: var(--sklearn-color-unfitted-level-1);\n", "}\n", "\n", ".sk-estimator-doc-link.fitted,\n", "a:link.sk-estimator-doc-link.fitted,\n", "a:visited.sk-estimator-doc-link.fitted {\n", " /* fitted */\n", " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", " color: var(--sklearn-color-fitted-level-1);\n", "}\n", "\n", "/* On hover */\n", "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n", ".sk-estimator-doc-link:hover,\n", "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n", ".sk-estimator-doc-link:hover {\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-3);\n", " color: var(--sklearn-color-background);\n", " text-decoration: none;\n", "}\n", "\n", "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n", ".sk-estimator-doc-link.fitted:hover,\n", "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n", ".sk-estimator-doc-link.fitted:hover {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-3);\n", " color: var(--sklearn-color-background);\n", " text-decoration: none;\n", "}\n", "\n", "/* Span, style for the box shown on hovering the info icon */\n", ".sk-estimator-doc-link span {\n", " display: none;\n", " z-index: 9999;\n", " position: relative;\n", " font-weight: normal;\n", " right: .2ex;\n", " padding: .5ex;\n", " margin: .5ex;\n", " width: min-content;\n", " min-width: 20ex;\n", " max-width: 50ex;\n", " 
color: var(--sklearn-color-text);\n", " box-shadow: 2pt 2pt 4pt #999;\n", " /* unfitted */\n", " background: var(--sklearn-color-unfitted-level-0);\n", " border: .5pt solid var(--sklearn-color-unfitted-level-3);\n", "}\n", "\n", ".sk-estimator-doc-link.fitted span {\n", " /* fitted */\n", " background: var(--sklearn-color-fitted-level-0);\n", " border: var(--sklearn-color-fitted-level-3);\n", "}\n", "\n", ".sk-estimator-doc-link:hover span {\n", " display: block;\n", "}\n", "\n", "/* \"?\"-specific style due to the `<a>` HTML tag */\n", "\n", "#sk-container-id-1 a.estimator_doc_link {\n", " float: right;\n", " font-size: 1rem;\n", " line-height: 1em;\n", " font-family: monospace;\n", " background-color: var(--sklearn-color-background);\n", " border-radius: 1rem;\n", " height: 1rem;\n", " width: 1rem;\n", " text-decoration: none;\n", " /* unfitted */\n", " color: var(--sklearn-color-unfitted-level-1);\n", " border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n", "}\n", "\n", "#sk-container-id-1 a.estimator_doc_link.fitted {\n", " /* fitted */\n", " border: var(--sklearn-color-fitted-level-1) 1pt solid;\n", " color: var(--sklearn-color-fitted-level-1);\n", "}\n", "\n", "/* On hover */\n", "#sk-container-id-1 a.estimator_doc_link:hover {\n", " /* unfitted */\n", " background-color: var(--sklearn-color-unfitted-level-3);\n", " color: var(--sklearn-color-background);\n", " text-decoration: none;\n", "}\n", "\n", "#sk-container-id-1 a.estimator_doc_link.fitted:hover {\n", " /* fitted */\n", " background-color: var(--sklearn-color-fitted-level-3);\n", "}\n", "</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>XGBRegressor(base_score=None, booster=None, callbacks=None,\n", " colsample_bylevel=None, colsample_bynode=None,\n", " colsample_bytree=None, device=None, early_stopping_rounds=None,\n", " enable_categorical=False, eval_metric=None, feature_types=None,\n", " gamma=None, grow_policy=None, 
importance_type=None,\n", " interaction_constraints=None, learning_rate=0.05, max_bin=None,\n", " max_cat_threshold=None, max_cat_to_onehot=None,\n", " max_delta_step=None, max_depth=4, max_leaves=None,\n", " min_child_weight=None, missing=nan, monotone_constraints=None,\n", " multi_strategy=None, n_estimators=2000, n_jobs=None,\n", " num_parallel_tree=None, random_state=None, ...)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\"> XGBRegressor<span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>XGBRegressor(base_score=None, booster=None, callbacks=None,\n", " colsample_bylevel=None, colsample_bynode=None,\n", " colsample_bytree=None, device=None, early_stopping_rounds=None,\n", " enable_categorical=False, eval_metric=None, feature_types=None,\n", " gamma=None, grow_policy=None, importance_type=None,\n", " interaction_constraints=None, learning_rate=0.05, max_bin=None,\n", " max_cat_threshold=None, max_cat_to_onehot=None,\n", " max_delta_step=None, max_depth=4, max_leaves=None,\n", " min_child_weight=None, missing=nan, monotone_constraints=None,\n", " multi_strategy=None, n_estimators=2000, n_jobs=None,\n", " num_parallel_tree=None, random_state=None, ...)</pre></div> </div></div></div></div>" ], "text/plain": [ "XGBRegressor(base_score=None, booster=None, callbacks=None,\n", " colsample_bylevel=None, colsample_bynode=None,\n", " colsample_bytree=None, device=None, 
early_stopping_rounds=None,\n", " enable_categorical=False, eval_metric=None, feature_types=None,\n", " gamma=None, grow_policy=None, importance_type=None,\n", " interaction_constraints=None, learning_rate=0.05, max_bin=None,\n", " max_cat_threshold=None, max_cat_to_onehot=None,\n", " max_delta_step=None, max_depth=4, max_leaves=None,\n", " min_child_weight=None, missing=nan, monotone_constraints=None,\n", " multi_strategy=None, n_estimators=2000, n_jobs=None,\n", " num_parallel_tree=None, random_state=None, ...)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Fit an XGBoost regressor on df_embeddings_train against the 'y' column of df_train_normalized\n", "xgb_predict = XGBRegressor(\n", "    max_depth=4,\n", "    learning_rate=0.05,\n", "    n_estimators=2000,\n", ")\n", "xgb_predict.fit(df_embeddings_train, df_train_normalized['y'])" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "# Run the fitted XGBoost model on df_embeddings_test\n", "y_pred = xgb_predict.predict(df_embeddings_test)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "RMSE Score: 0.6485\n" ] } ], "source": [ "# Root-mean-squared error of y_pred against the 'y' column of df_test_normalized\n", "rmse = np.sqrt(\n", "    mean_squared_error(y_true=df_test_normalized[\"y\"], y_pred=y_pred)\n", ")\n", "print(f\"RMSE Score: {rmse:.4f}\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 2 }