{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# materials.smi-TED - INFERENCE (Regression)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install extra packages for notebook\n",
    "%pip install seaborn xgboost"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "# Add the sibling ../inference directory to the module search path so that the\n",
    "# smi_ted_light package imported in the next cell can be resolved.\n",
    "# NOTE: the path is relative, i.e. it is interpreted against the kernel's working directory.\n",
    "sys.path.append('../inference')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# materials.smi-ted (smi-ted)\n",
    "from smi_ted_light.load import load_smi_ted\n",
    "\n",
    "# Data\n",
    "import torch\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "\n",
    "# Chemistry\n",
    "from rdkit import Chem\n",
    "from rdkit.Chem import PandasTools\n",
    "from rdkit.Chem import Descriptors\n",
    "PandasTools.RenderImagesInAllDataFrames(True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# function to canonicalize SMILES\n",
    "def normalize_smiles(smi, canonical=True, isomeric=False):\n",
    "    \"\"\"Return the RDKit-normalized form of a SMILES string, or None on failure.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    smi : str\n",
    "        SMILES string to normalize.\n",
    "    canonical : bool, default True\n",
    "        Forwarded to Chem.MolToSmiles as ``canonical``.\n",
    "    isomeric : bool, default False\n",
    "        Forwarded to Chem.MolToSmiles as ``isomericSmiles``.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    str or None\n",
    "        The SMILES written by RDKit, or None when `smi` cannot be parsed or\n",
    "        converted (for example an invalid SMILES or a non-string value).\n",
    "    \"\"\"\n",
    "    try:\n",
    "        mol = Chem.MolFromSmiles(smi)\n",
    "        # RDKit reports an unparsable SMILES by returning None rather than raising.\n",
    "        if mol is None:\n",
    "            return None\n",
    "        return Chem.MolToSmiles(mol, canonical=canonical, isomericSmiles=isomeric)\n",
    "    except Exception:\n",
    "        # Deliberately best-effort: any RDKit/type error maps to None so callers can\n",
    "        # drop the row. Unlike a bare `except:`, this lets KeyboardInterrupt and\n",
    "        # SystemExit propagate, so a long-running .apply() can still be interrupted.\n",
    "        return None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Import smi-ted"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Random Seed: 12345\n",
      "Using Rotation Embedding\n",
      "Using Rotation Embedding\n",
      "Using Rotation Embedding\n",
      "Using Rotation Embedding\n",
      "Using Rotation Embedding\n",
      "Using Rotation Embedding\n",
      "Using Rotation Embedding\n",
      "Using Rotation Embedding\n",
      "Using Rotation Embedding\n",
      "Using Rotation Embedding\n",
      "Using Rotation Embedding\n",
      "Using Rotation Embedding\n",
      "Vocab size: 2393\n",
      "[INFERENCE MODE - smi-ted-Light]\n"
     ]
    }
   ],
   "source": [
    "# Load the smi-ted Light checkpoint (smi-ted-Light_40.pt) from the local inference folder.\n",
    "# The loader's saved output above reports that the model is in inference mode.\n",
    "model_smi_ted = load_smi_ted(\n",
    "    folder='../inference/smi_ted_light',\n",
    "    ckpt_filename='smi-ted-Light_40.pt'\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Lipophilicity Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Experiments - Data Load"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train/test splits of the lipophilicity dataset, read from the repo's finetune/moleculenet folder.\n",
    "# The cells below read the `smiles` column; `y` appears to be the regression target (TODO confirm).\n",
    "df_train = pd.read_csv(\"../finetune/moleculenet/lipophilicity/train.csv\")\n",
    "df_test = pd.read_csv(\"../finetune/moleculenet/lipophilicity/test.csv\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### SMILES canonicalization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(3360, 3)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>y</th>\n",
       "      <th>norm_smiles</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Nc1ncnc2c1c(COc3cccc(Cl)c3)nn2C4CCOCC4</td>\n",
       "      <td>0.814313</td>\n",
       "      <td>Nc1ncnc2c1c(COc1cccc(Cl)c1)nn2C1CCOCC1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>COc1cc(cc2cnc(Nc3ccc(cc3)[C@@H](C)NC(=O)C)nc12...</td>\n",
       "      <td>0.446346</td>\n",
       "      <td>COc1cc(-c2ccncc2)cc2cnc(Nc3ccc(C(C)NC(C)=O)cc3...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>CC(=O)Nc1ccc2ccn(c3cc(Nc4ccn(C)n4)n5ncc(C#N)c5...</td>\n",
       "      <td>1.148828</td>\n",
       "      <td>CC(=O)Nc1ccc2ccn(-c3cc(Nc4ccn(C)n4)n4ncc(C#N)c...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Oc1ccc(CCNCCS(=O)(=O)CCCOCCSc2ccccc2)c3sc(O)nc13</td>\n",
       "      <td>0.404532</td>\n",
       "      <td>O=S(=O)(CCCOCCSc1ccccc1)CCNCCc1ccc(O)c2nc(O)sc12</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Clc1ccc2C(=O)C3=C(Nc2c1)C(=O)NN(Cc4cc5ccccc5s4...</td>\n",
       "      <td>-0.164144</td>\n",
       "      <td>O=c1[nH]n(Cc2cc3ccccc3s2)c(=O)c2c(=O)c3ccc(Cl)...</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                              smiles         y  \\\n",
       "0             Nc1ncnc2c1c(COc3cccc(Cl)c3)nn2C4CCOCC4  0.814313   \n",
       "1  COc1cc(cc2cnc(Nc3ccc(cc3)[C@@H](C)NC(=O)C)nc12...  0.446346   \n",
       "2  CC(=O)Nc1ccc2ccn(c3cc(Nc4ccn(C)n4)n5ncc(C#N)c5...  1.148828   \n",
       "3   Oc1ccc(CCNCCS(=O)(=O)CCCOCCSc2ccccc2)c3sc(O)nc13  0.404532   \n",
       "4  Clc1ccc2C(=O)C3=C(Nc2c1)C(=O)NN(Cc4cc5ccccc5s4... -0.164144   \n",
       "\n",
       "                                         norm_smiles  \n",
       "0             Nc1ncnc2c1c(COc1cccc(Cl)c1)nn2C1CCOCC1  \n",
       "1  COc1cc(-c2ccncc2)cc2cnc(Nc3ccc(C(C)NC(C)=O)cc3...  \n",
       "2  CC(=O)Nc1ccc2ccn(-c3cc(Nc4ccn(C)n4)n4ncc(C#N)c...  \n",
       "3   O=S(=O)(CCCOCCSc1ccccc1)CCNCCc1ccc(O)c2nc(O)sc12  \n",
       "4  O=c1[nH]n(Cc2cc3ccccc3s2)c(=O)c2c(=O)c3ccc(Cl)...  "
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Add the normalized SMILES as a new column. With the defaults used here\n",
    "# (canonical=True, isomeric=False) stereo markers are not written, and\n",
    "# normalize_smiles yields None for inputs it cannot process.\n",
    "df_train['norm_smiles'] = df_train['smiles'].apply(normalize_smiles)\n",
    "# dropna() removes rows with a missing value in ANY column, which includes\n",
    "# the rows whose normalization failed.\n",
    "df_train_normalized = df_train.dropna()\n",
    "print(df_train_normalized.shape)\n",
    "df_train_normalized.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(420, 3)\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>smiles</th>\n",
       "      <th>y</th>\n",
       "      <th>norm_smiles</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>N(c1ccccc1)c2ccnc3ccccc23</td>\n",
       "      <td>0.488161</td>\n",
       "      <td>c1ccc(Nc2ccnc3ccccc23)cc1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Clc1ccc2Oc3ccccc3N=C(N4CCNCC4)c2c1</td>\n",
       "      <td>0.070017</td>\n",
       "      <td>Clc1ccc2c(c1)C(N1CCNCC1)=Nc1ccccc1O2</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>NC1(CCC1)c2ccc(cc2)c3ncc4cccnc4c3c5ccccc5</td>\n",
       "      <td>-0.415030</td>\n",
       "      <td>NC1(c2ccc(-c3ncc4cccnc4c3-c3ccccc3)cc2)CCC1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>OC[C@H](O)CN1C(=O)[C@@H](Cc2ccccc12)NC(=O)c3cc...</td>\n",
       "      <td>0.897942</td>\n",
       "      <td>O=C(NC1Cc2ccccc2N(CC(O)CO)C1=O)c1cc2cc(Cl)sc2[...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>NS(=O)(=O)c1nc2ccccc2s1</td>\n",
       "      <td>-0.707731</td>\n",
       "      <td>NS(=O)(=O)c1nc2ccccc2s1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                              smiles         y  \\\n",
       "0                          N(c1ccccc1)c2ccnc3ccccc23  0.488161   \n",
       "1                 Clc1ccc2Oc3ccccc3N=C(N4CCNCC4)c2c1  0.070017   \n",
       "2          NC1(CCC1)c2ccc(cc2)c3ncc4cccnc4c3c5ccccc5 -0.415030   \n",
       "3  OC[C@H](O)CN1C(=O)[C@@H](Cc2ccccc12)NC(=O)c3cc...  0.897942   \n",
       "4                            NS(=O)(=O)c1nc2ccccc2s1 -0.707731   \n",
       "\n",
       "                                         norm_smiles  \n",
       "0                          c1ccc(Nc2ccnc3ccccc23)cc1  \n",
       "1               Clc1ccc2c(c1)C(N1CCNCC1)=Nc1ccccc1O2  \n",
       "2        NC1(c2ccc(-c3ncc4cccnc4c3-c3ccccc3)cc2)CCC1  \n",
       "3  O=C(NC1Cc2ccccc2N(CC(O)CO)C1=O)c1cc2cc(Cl)sc2[...  \n",
       "4                            NS(=O)(=O)c1nc2ccccc2s1  "
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same normalization as for the training split: None marks SMILES that\n",
    "# normalize_smiles could not process.\n",
    "df_test['norm_smiles'] = df_test['smiles'].apply(normalize_smiles)\n",
    "# dropna() removes rows with a missing value in ANY column, which includes\n",
    "# the rows whose normalization failed.\n",
    "df_test_normalized = df_test.dropna()\n",
    "print(df_test_normalized.shape)\n",
    "df_test_normalized.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Embeddings extraction "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### smi-ted embeddings extraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 33/33 [00:38<00:00,  1.15s/it]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>9</th>\n",
       "      <th>...</th>\n",
       "      <th>758</th>\n",
       "      <th>759</th>\n",
       "      <th>760</th>\n",
       "      <th>761</th>\n",
       "      <th>762</th>\n",
       "      <th>763</th>\n",
       "      <th>764</th>\n",
       "      <th>765</th>\n",
       "      <th>766</th>\n",
       "      <th>767</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.367646</td>\n",
       "      <td>-0.504889</td>\n",
       "      <td>0.040485</td>\n",
       "      <td>0.385314</td>\n",
       "      <td>0.564923</td>\n",
       "      <td>-0.684497</td>\n",
       "      <td>1.160397</td>\n",
       "      <td>0.071218</td>\n",
       "      <td>0.799428</td>\n",
       "      <td>0.181323</td>\n",
       "      <td>...</td>\n",
       "      <td>-1.379994</td>\n",
       "      <td>-0.167221</td>\n",
       "      <td>0.104886</td>\n",
       "      <td>0.239571</td>\n",
       "      <td>-0.744390</td>\n",
       "      <td>0.590423</td>\n",
       "      <td>-0.808946</td>\n",
       "      <td>0.792584</td>\n",
       "      <td>0.550898</td>\n",
       "      <td>-0.176831</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.455316</td>\n",
       "      <td>-0.485554</td>\n",
       "      <td>0.062206</td>\n",
       "      <td>0.387994</td>\n",
       "      <td>0.567590</td>\n",
       "      <td>-0.713285</td>\n",
       "      <td>1.144267</td>\n",
       "      <td>-0.057046</td>\n",
       "      <td>0.753016</td>\n",
       "      <td>0.112180</td>\n",
       "      <td>...</td>\n",
       "      <td>-1.332142</td>\n",
       "      <td>-0.096662</td>\n",
       "      <td>0.221944</td>\n",
       "      <td>0.327923</td>\n",
       "      <td>-0.739358</td>\n",
       "      <td>0.659803</td>\n",
       "      <td>-0.775723</td>\n",
       "      <td>0.745837</td>\n",
       "      <td>0.566330</td>\n",
       "      <td>-0.111946</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.442309</td>\n",
       "      <td>-0.484732</td>\n",
       "      <td>0.084945</td>\n",
       "      <td>0.384787</td>\n",
       "      <td>0.564752</td>\n",
       "      <td>-0.704130</td>\n",
       "      <td>1.159491</td>\n",
       "      <td>0.021168</td>\n",
       "      <td>0.846539</td>\n",
       "      <td>0.118463</td>\n",
       "      <td>...</td>\n",
       "      <td>-1.324177</td>\n",
       "      <td>-0.110403</td>\n",
       "      <td>0.207824</td>\n",
       "      <td>0.281665</td>\n",
       "      <td>-0.780818</td>\n",
       "      <td>0.693484</td>\n",
       "      <td>-0.832626</td>\n",
       "      <td>0.763095</td>\n",
       "      <td>0.532460</td>\n",
       "      <td>-0.196708</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.527961</td>\n",
       "      <td>-0.519151</td>\n",
       "      <td>0.091635</td>\n",
       "      <td>0.353518</td>\n",
       "      <td>0.421795</td>\n",
       "      <td>-0.724220</td>\n",
       "      <td>1.093752</td>\n",
       "      <td>0.148574</td>\n",
       "      <td>0.804047</td>\n",
       "      <td>0.194627</td>\n",
       "      <td>...</td>\n",
       "      <td>-1.358414</td>\n",
       "      <td>-0.111483</td>\n",
       "      <td>0.151692</td>\n",
       "      <td>0.186741</td>\n",
       "      <td>-0.601867</td>\n",
       "      <td>0.641591</td>\n",
       "      <td>-0.747422</td>\n",
       "      <td>0.794239</td>\n",
       "      <td>0.640765</td>\n",
       "      <td>-0.239649</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.464432</td>\n",
       "      <td>-0.511090</td>\n",
       "      <td>0.038785</td>\n",
       "      <td>0.346217</td>\n",
       "      <td>0.492919</td>\n",
       "      <td>-0.619387</td>\n",
       "      <td>1.048157</td>\n",
       "      <td>0.095910</td>\n",
       "      <td>0.738604</td>\n",
       "      <td>0.119270</td>\n",
       "      <td>...</td>\n",
       "      <td>-1.223927</td>\n",
       "      <td>-0.109863</td>\n",
       "      <td>0.151280</td>\n",
       "      <td>0.244834</td>\n",
       "      <td>-0.686610</td>\n",
       "      <td>0.759327</td>\n",
       "      <td>-0.756338</td>\n",
       "      <td>0.766427</td>\n",
       "      <td>0.610454</td>\n",
       "      <td>-0.197345</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 768 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        0         1         2         3         4         5         6    \\\n",
       "0  0.367646 -0.504889  0.040485  0.385314  0.564923 -0.684497  1.160397   \n",
       "1  0.455316 -0.485554  0.062206  0.387994  0.567590 -0.713285  1.144267   \n",
       "2  0.442309 -0.484732  0.084945  0.384787  0.564752 -0.704130  1.159491   \n",
       "3  0.527961 -0.519151  0.091635  0.353518  0.421795 -0.724220  1.093752   \n",
       "4  0.464432 -0.511090  0.038785  0.346217  0.492919 -0.619387  1.048157   \n",
       "\n",
       "        7         8         9    ...       758       759       760       761  \\\n",
       "0  0.071218  0.799428  0.181323  ... -1.379994 -0.167221  0.104886  0.239571   \n",
       "1 -0.057046  0.753016  0.112180  ... -1.332142 -0.096662  0.221944  0.327923   \n",
       "2  0.021168  0.846539  0.118463  ... -1.324177 -0.110403  0.207824  0.281665   \n",
       "3  0.148574  0.804047  0.194627  ... -1.358414 -0.111483  0.151692  0.186741   \n",
       "4  0.095910  0.738604  0.119270  ... -1.223927 -0.109863  0.151280  0.244834   \n",
       "\n",
       "        762       763       764       765       766       767  \n",
       "0 -0.744390  0.590423 -0.808946  0.792584  0.550898 -0.176831  \n",
       "1 -0.739358  0.659803 -0.775723  0.745837  0.566330 -0.111946  \n",
       "2 -0.780818  0.693484 -0.832626  0.763095  0.532460 -0.196708  \n",
       "3 -0.601867  0.641591 -0.747422  0.794239  0.640765 -0.239649  \n",
       "4 -0.686610  0.759327 -0.756338  0.766427  0.610454 -0.197345  \n",
       "\n",
       "[5 rows x 768 columns]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inference only: disable autograd tracking while the model encodes the SMILES.\n",
    "with torch.no_grad():\n",
    "    df_embeddings_train = model_smi_ted.encode(df_train_normalized['norm_smiles'])\n",
    "# The saved output shows 768 embedding columns; presumably one row per input SMILES.\n",
    "df_embeddings_train.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 4/4 [00:05<00:00,  1.46s/it]\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "      <th>3</th>\n",
       "      <th>4</th>\n",
       "      <th>5</th>\n",
       "      <th>6</th>\n",
       "      <th>7</th>\n",
       "      <th>8</th>\n",
       "      <th>9</th>\n",
       "      <th>...</th>\n",
       "      <th>758</th>\n",
       "      <th>759</th>\n",
       "      <th>760</th>\n",
       "      <th>761</th>\n",
       "      <th>762</th>\n",
       "      <th>763</th>\n",
       "      <th>764</th>\n",
       "      <th>765</th>\n",
       "      <th>766</th>\n",
       "      <th>767</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0.392252</td>\n",
       "      <td>-0.504846</td>\n",
       "      <td>0.056791</td>\n",
       "      <td>0.356297</td>\n",
       "      <td>0.475918</td>\n",
       "      <td>-0.648899</td>\n",
       "      <td>1.157862</td>\n",
       "      <td>-0.022914</td>\n",
       "      <td>0.703240</td>\n",
       "      <td>0.192023</td>\n",
       "      <td>...</td>\n",
       "      <td>-1.208714</td>\n",
       "      <td>-0.094441</td>\n",
       "      <td>0.128845</td>\n",
       "      <td>0.403995</td>\n",
       "      <td>-0.782782</td>\n",
       "      <td>0.541907</td>\n",
       "      <td>-0.707272</td>\n",
       "      <td>0.901041</td>\n",
       "      <td>0.629461</td>\n",
       "      <td>-0.020630</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0.387422</td>\n",
       "      <td>-0.481142</td>\n",
       "      <td>0.049675</td>\n",
       "      <td>0.353058</td>\n",
       "      <td>0.601170</td>\n",
       "      <td>-0.646099</td>\n",
       "      <td>1.142392</td>\n",
       "      <td>0.060092</td>\n",
       "      <td>0.763799</td>\n",
       "      <td>0.110331</td>\n",
       "      <td>...</td>\n",
       "      <td>-1.248282</td>\n",
       "      <td>-0.139790</td>\n",
       "      <td>0.075585</td>\n",
       "      <td>0.202242</td>\n",
       "      <td>-0.729794</td>\n",
       "      <td>0.705914</td>\n",
       "      <td>-0.771751</td>\n",
       "      <td>0.843173</td>\n",
       "      <td>0.618850</td>\n",
       "      <td>-0.213584</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0.390975</td>\n",
       "      <td>-0.510056</td>\n",
       "      <td>0.070656</td>\n",
       "      <td>0.380695</td>\n",
       "      <td>0.601486</td>\n",
       "      <td>-0.595827</td>\n",
       "      <td>1.182193</td>\n",
       "      <td>0.011085</td>\n",
       "      <td>0.688093</td>\n",
       "      <td>0.056453</td>\n",
       "      <td>...</td>\n",
       "      <td>-1.294595</td>\n",
       "      <td>-0.164846</td>\n",
       "      <td>0.194435</td>\n",
       "      <td>0.240742</td>\n",
       "      <td>-0.773443</td>\n",
       "      <td>0.608631</td>\n",
       "      <td>-0.747181</td>\n",
       "      <td>0.791911</td>\n",
       "      <td>0.611874</td>\n",
       "      <td>-0.125455</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0.423924</td>\n",
       "      <td>-0.557325</td>\n",
       "      <td>0.083810</td>\n",
       "      <td>0.328703</td>\n",
       "      <td>0.399589</td>\n",
       "      <td>-0.622818</td>\n",
       "      <td>1.079945</td>\n",
       "      <td>0.097611</td>\n",
       "      <td>0.724030</td>\n",
       "      <td>0.135976</td>\n",
       "      <td>...</td>\n",
       "      <td>-1.412060</td>\n",
       "      <td>-0.106541</td>\n",
       "      <td>0.153314</td>\n",
       "      <td>0.209962</td>\n",
       "      <td>-0.699690</td>\n",
       "      <td>0.648061</td>\n",
       "      <td>-0.716241</td>\n",
       "      <td>0.757986</td>\n",
       "      <td>0.615963</td>\n",
       "      <td>-0.258693</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0.335576</td>\n",
       "      <td>-0.559591</td>\n",
       "      <td>0.119437</td>\n",
       "      <td>0.364141</td>\n",
       "      <td>0.375474</td>\n",
       "      <td>-0.639833</td>\n",
       "      <td>1.144707</td>\n",
       "      <td>0.077512</td>\n",
       "      <td>0.791759</td>\n",
       "      <td>0.164201</td>\n",
       "      <td>...</td>\n",
       "      <td>-1.279041</td>\n",
       "      <td>-0.186733</td>\n",
       "      <td>0.106963</td>\n",
       "      <td>0.254949</td>\n",
       "      <td>-0.651694</td>\n",
       "      <td>0.594167</td>\n",
       "      <td>-0.680426</td>\n",
       "      <td>0.887482</td>\n",
       "      <td>0.651587</td>\n",
       "      <td>-0.144996</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>5 rows × 768 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "        0         1         2         3         4         5         6    \\\n",
       "0  0.392252 -0.504846  0.056791  0.356297  0.475918 -0.648899  1.157862   \n",
       "1  0.387422 -0.481142  0.049675  0.353058  0.601170 -0.646099  1.142392   \n",
       "2  0.390975 -0.510056  0.070656  0.380695  0.601486 -0.595827  1.182193   \n",
       "3  0.423924 -0.557325  0.083810  0.328703  0.399589 -0.622818  1.079945   \n",
       "4  0.335576 -0.559591  0.119437  0.364141  0.375474 -0.639833  1.144707   \n",
       "\n",
       "        7         8         9    ...       758       759       760       761  \\\n",
       "0 -0.022914  0.703240  0.192023  ... -1.208714 -0.094441  0.128845  0.403995   \n",
       "1  0.060092  0.763799  0.110331  ... -1.248282 -0.139790  0.075585  0.202242   \n",
       "2  0.011085  0.688093  0.056453  ... -1.294595 -0.164846  0.194435  0.240742   \n",
       "3  0.097611  0.724030  0.135976  ... -1.412060 -0.106541  0.153314  0.209962   \n",
       "4  0.077512  0.791759  0.164201  ... -1.279041 -0.186733  0.106963  0.254949   \n",
       "\n",
       "        762       763       764       765       766       767  \n",
       "0 -0.782782  0.541907 -0.707272  0.901041  0.629461 -0.020630  \n",
       "1 -0.729794  0.705914 -0.771751  0.843173  0.618850 -0.213584  \n",
       "2 -0.773443  0.608631 -0.747181  0.791911  0.611874 -0.125455  \n",
       "3 -0.699690  0.648061 -0.716241  0.757986  0.615963 -0.258693  \n",
       "4 -0.651694  0.594167 -0.680426  0.887482  0.651587 -0.144996  \n",
       "\n",
       "[5 rows x 768 columns]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Inference only: disable autograd tracking while the model encodes the SMILES.\n",
    "with torch.no_grad():\n",
    "    df_embeddings_test = model_smi_ted.encode(df_test_normalized['norm_smiles'])\n",
    "# The saved output shows 768 embedding columns; presumably one row per input SMILES.\n",
    "df_embeddings_test.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Experiments - Lipophilicity prediction using smi-ted latent spaces"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "#### XGBoost prediction using the whole Latent Space"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "from xgboost import XGBRegressor\n",
    "from sklearn.metrics import mean_squared_error"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style>#sk-container-id-1 {\n",
       "  /* Definition of color scheme common for light and dark mode */\n",
       "  --sklearn-color-text: black;\n",
       "  --sklearn-color-line: gray;\n",
       "  /* Definition of color scheme for unfitted estimators */\n",
       "  --sklearn-color-unfitted-level-0: #fff5e6;\n",
       "  --sklearn-color-unfitted-level-1: #f6e4d2;\n",
       "  --sklearn-color-unfitted-level-2: #ffe0b3;\n",
       "  --sklearn-color-unfitted-level-3: chocolate;\n",
       "  /* Definition of color scheme for fitted estimators */\n",
       "  --sklearn-color-fitted-level-0: #f0f8ff;\n",
       "  --sklearn-color-fitted-level-1: #d4ebff;\n",
       "  --sklearn-color-fitted-level-2: #b3dbfd;\n",
       "  --sklearn-color-fitted-level-3: cornflowerblue;\n",
       "\n",
       "  /* Specific color for light theme */\n",
       "  --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
       "  --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
       "  --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
       "  --sklearn-color-icon: #696969;\n",
       "\n",
       "  @media (prefers-color-scheme: dark) {\n",
       "    /* Redefinition of color scheme for dark theme */\n",
       "    --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
       "    --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
       "    --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
       "    --sklearn-color-icon: #878787;\n",
       "  }\n",
       "}\n",
       "\n",
       "#sk-container-id-1 {\n",
       "  color: var(--sklearn-color-text);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 pre {\n",
       "  padding: 0;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 input.sk-hidden--visually {\n",
       "  border: 0;\n",
       "  clip: rect(1px 1px 1px 1px);\n",
       "  clip: rect(1px, 1px, 1px, 1px);\n",
       "  height: 1px;\n",
       "  margin: -1px;\n",
       "  overflow: hidden;\n",
       "  padding: 0;\n",
       "  position: absolute;\n",
       "  width: 1px;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-dashed-wrapped {\n",
       "  border: 1px dashed var(--sklearn-color-line);\n",
       "  margin: 0 0.4em 0.5em 0.4em;\n",
       "  box-sizing: border-box;\n",
       "  padding-bottom: 0.4em;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-container {\n",
       "  /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
       "     but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
       "     so we also need the `!important` here to be able to override the\n",
       "     default hidden behavior on the sphinx rendered scikit-learn.org.\n",
       "     See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
       "  display: inline-block !important;\n",
       "  position: relative;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-text-repr-fallback {\n",
       "  display: none;\n",
       "}\n",
       "\n",
       "div.sk-parallel-item,\n",
       "div.sk-serial,\n",
       "div.sk-item {\n",
       "  /* draw centered vertical line to link estimators */\n",
       "  background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
       "  background-size: 2px 100%;\n",
       "  background-repeat: no-repeat;\n",
       "  background-position: center center;\n",
       "}\n",
       "\n",
       "/* Parallel-specific style estimator block */\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item::after {\n",
       "  content: \"\";\n",
       "  width: 100%;\n",
       "  border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
       "  flex-grow: 1;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel {\n",
       "  display: flex;\n",
       "  align-items: stretch;\n",
       "  justify-content: center;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  position: relative;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item:first-child::after {\n",
       "  align-self: flex-end;\n",
       "  width: 50%;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item:last-child::after {\n",
       "  align-self: flex-start;\n",
       "  width: 50%;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-parallel-item:only-child::after {\n",
       "  width: 0;\n",
       "}\n",
       "\n",
       "/* Serial-specific style estimator block */\n",
       "\n",
       "#sk-container-id-1 div.sk-serial {\n",
       "  display: flex;\n",
       "  flex-direction: column;\n",
       "  align-items: center;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  padding-right: 1em;\n",
       "  padding-left: 1em;\n",
       "}\n",
       "\n",
       "\n",
       "/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
       "clickable and can be expanded/collapsed.\n",
       "- Pipeline and ColumnTransformer use this feature and define the default style\n",
       "- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
       "*/\n",
       "\n",
       "/* Pipeline and ColumnTransformer style (default) */\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable {\n",
       "  /* Default theme specific background. It is overwritten whether we have a\n",
       "  specific estimator or a Pipeline/ColumnTransformer */\n",
       "  background-color: var(--sklearn-color-background);\n",
       "}\n",
       "\n",
       "/* Toggleable label */\n",
       "#sk-container-id-1 label.sk-toggleable__label {\n",
       "  cursor: pointer;\n",
       "  display: block;\n",
       "  width: 100%;\n",
       "  margin-bottom: 0;\n",
       "  padding: 0.5em;\n",
       "  box-sizing: border-box;\n",
       "  text-align: center;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 label.sk-toggleable__label-arrow:before {\n",
       "  /* Arrow on the left of the label */\n",
       "  content: \"▸\";\n",
       "  float: left;\n",
       "  margin-right: 0.25em;\n",
       "  color: var(--sklearn-color-icon);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {\n",
       "  color: var(--sklearn-color-text);\n",
       "}\n",
       "\n",
       "/* Toggleable content - dropdown */\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content {\n",
       "  max-height: 0;\n",
       "  max-width: 0;\n",
       "  overflow: hidden;\n",
       "  text-align: left;\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content.fitted {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content pre {\n",
       "  margin: 0.2em;\n",
       "  border-radius: 0.25em;\n",
       "  color: var(--sklearn-color-text);\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-toggleable__content.fitted pre {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
       "  /* Expand drop-down */\n",
       "  max-height: 200px;\n",
       "  max-width: 100%;\n",
       "  overflow: auto;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
       "  content: \"▾\";\n",
       "}\n",
       "\n",
       "/* Pipeline/ColumnTransformer-specific style */\n",
       "\n",
       "#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Estimator-specific style */\n",
       "\n",
       "/* Colorize estimator box */\n",
       "#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-label label.sk-toggleable__label,\n",
       "#sk-container-id-1 div.sk-label label {\n",
       "  /* The background is the default theme color */\n",
       "  color: var(--sklearn-color-text-on-default-background);\n",
       "}\n",
       "\n",
       "/* On hover, darken the color of the background */\n",
       "#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "/* Label box, darken color on hover, fitted */\n",
       "#sk-container-id-1 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
       "  color: var(--sklearn-color-text);\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Estimator label */\n",
       "\n",
       "#sk-container-id-1 div.sk-label label {\n",
       "  font-family: monospace;\n",
       "  font-weight: bold;\n",
       "  display: inline-block;\n",
       "  line-height: 1.2em;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-label-container {\n",
       "  text-align: center;\n",
       "}\n",
       "\n",
       "/* Estimator-specific */\n",
       "#sk-container-id-1 div.sk-estimator {\n",
       "  font-family: monospace;\n",
       "  border: 1px dotted var(--sklearn-color-border-box);\n",
       "  border-radius: 0.25em;\n",
       "  box-sizing: border-box;\n",
       "  margin-bottom: 0.5em;\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-0);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-estimator.fitted {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-0);\n",
       "}\n",
       "\n",
       "/* on hover */\n",
       "#sk-container-id-1 div.sk-estimator:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-2);\n",
       "}\n",
       "\n",
       "#sk-container-id-1 div.sk-estimator.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-2);\n",
       "}\n",
       "\n",
       "/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
       "\n",
       "/* Common style for \"i\" and \"?\" */\n",
       "\n",
       ".sk-estimator-doc-link,\n",
       "a:link.sk-estimator-doc-link,\n",
       "a:visited.sk-estimator-doc-link {\n",
       "  float: right;\n",
       "  font-size: smaller;\n",
       "  line-height: 1em;\n",
       "  font-family: monospace;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  border-radius: 1em;\n",
       "  height: 1em;\n",
       "  width: 1em;\n",
       "  text-decoration: none !important;\n",
       "  margin-left: 1ex;\n",
       "  /* unfitted */\n",
       "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-unfitted-level-1);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link.fitted,\n",
       "a:link.sk-estimator-doc-link.fitted,\n",
       "a:visited.sk-estimator-doc-link.fitted {\n",
       "  /* fitted */\n",
       "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-fitted-level-1);\n",
       "}\n",
       "\n",
       "/* On hover */\n",
       "div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
       ".sk-estimator-doc-link:hover,\n",
       "div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
       ".sk-estimator-doc-link:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
       ".sk-estimator-doc-link.fitted:hover,\n",
       "div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
       ".sk-estimator-doc-link.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "/* Span, style for the box shown on hovering the info icon */\n",
       ".sk-estimator-doc-link span {\n",
       "  display: none;\n",
       "  z-index: 9999;\n",
       "  position: relative;\n",
       "  font-weight: normal;\n",
       "  right: .2ex;\n",
       "  padding: .5ex;\n",
       "  margin: .5ex;\n",
       "  width: min-content;\n",
       "  min-width: 20ex;\n",
       "  max-width: 50ex;\n",
       "  color: var(--sklearn-color-text);\n",
       "  box-shadow: 2pt 2pt 4pt #999;\n",
       "  /* unfitted */\n",
       "  background: var(--sklearn-color-unfitted-level-0);\n",
       "  border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link.fitted span {\n",
       "  /* fitted */\n",
       "  background: var(--sklearn-color-fitted-level-0);\n",
       "  border: var(--sklearn-color-fitted-level-3);\n",
       "}\n",
       "\n",
       ".sk-estimator-doc-link:hover span {\n",
       "  display: block;\n",
       "}\n",
       "\n",
       "/* \"?\"-specific style due to the `<a>` HTML tag */\n",
       "\n",
       "#sk-container-id-1 a.estimator_doc_link {\n",
       "  float: right;\n",
       "  font-size: 1rem;\n",
       "  line-height: 1em;\n",
       "  font-family: monospace;\n",
       "  background-color: var(--sklearn-color-background);\n",
       "  border-radius: 1rem;\n",
       "  height: 1rem;\n",
       "  width: 1rem;\n",
       "  text-decoration: none;\n",
       "  /* unfitted */\n",
       "  color: var(--sklearn-color-unfitted-level-1);\n",
       "  border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 a.estimator_doc_link.fitted {\n",
       "  /* fitted */\n",
       "  border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
       "  color: var(--sklearn-color-fitted-level-1);\n",
       "}\n",
       "\n",
       "/* On hover */\n",
       "#sk-container-id-1 a.estimator_doc_link:hover {\n",
       "  /* unfitted */\n",
       "  background-color: var(--sklearn-color-unfitted-level-3);\n",
       "  color: var(--sklearn-color-background);\n",
       "  text-decoration: none;\n",
       "}\n",
       "\n",
       "#sk-container-id-1 a.estimator_doc_link.fitted:hover {\n",
       "  /* fitted */\n",
       "  background-color: var(--sklearn-color-fitted-level-3);\n",
       "}\n",
       "</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>XGBRegressor(base_score=None, booster=None, callbacks=None,\n",
       "             colsample_bylevel=None, colsample_bynode=None,\n",
       "             colsample_bytree=None, device=None, early_stopping_rounds=None,\n",
       "             enable_categorical=False, eval_metric=None, feature_types=None,\n",
       "             gamma=None, grow_policy=None, importance_type=None,\n",
       "             interaction_constraints=None, learning_rate=0.05, max_bin=None,\n",
       "             max_cat_threshold=None, max_cat_to_onehot=None,\n",
       "             max_delta_step=None, max_depth=4, max_leaves=None,\n",
       "             min_child_weight=None, missing=nan, monotone_constraints=None,\n",
       "             multi_strategy=None, n_estimators=2000, n_jobs=None,\n",
       "             num_parallel_tree=None, random_state=None, ...)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;XGBRegressor<span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>XGBRegressor(base_score=None, booster=None, callbacks=None,\n",
       "             colsample_bylevel=None, colsample_bynode=None,\n",
       "             colsample_bytree=None, device=None, early_stopping_rounds=None,\n",
       "             enable_categorical=False, eval_metric=None, feature_types=None,\n",
       "             gamma=None, grow_policy=None, importance_type=None,\n",
       "             interaction_constraints=None, learning_rate=0.05, max_bin=None,\n",
       "             max_cat_threshold=None, max_cat_to_onehot=None,\n",
       "             max_delta_step=None, max_depth=4, max_leaves=None,\n",
       "             min_child_weight=None, missing=nan, monotone_constraints=None,\n",
       "             multi_strategy=None, n_estimators=2000, n_jobs=None,\n",
       "             num_parallel_tree=None, random_state=None, ...)</pre></div> </div></div></div></div>"
      ],
      "text/plain": [
       "XGBRegressor(base_score=None, booster=None, callbacks=None,\n",
       "             colsample_bylevel=None, colsample_bynode=None,\n",
       "             colsample_bytree=None, device=None, early_stopping_rounds=None,\n",
       "             enable_categorical=False, eval_metric=None, feature_types=None,\n",
       "             gamma=None, grow_policy=None, importance_type=None,\n",
       "             interaction_constraints=None, learning_rate=0.05, max_bin=None,\n",
       "             max_cat_threshold=None, max_cat_to_onehot=None,\n",
       "             max_delta_step=None, max_depth=4, max_leaves=None,\n",
       "             min_child_weight=None, missing=nan, monotone_constraints=None,\n",
       "             multi_strategy=None, n_estimators=2000, n_jobs=None,\n",
       "             num_parallel_tree=None, random_state=None, ...)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train an XGBoost regressor on df_embeddings_train against the 'y' column\n",
    "# of df_train_normalized. The fitted estimator is the cell's last expression,\n",
    "# so its repr is displayed as the cell output.\n",
    "xgb_predict = XGBRegressor(\n",
    "    n_estimators=2000,\n",
    "    learning_rate=0.05,\n",
    "    max_depth=4,\n",
    ")\n",
    "xgb_predict.fit(df_embeddings_train, df_train_normalized['y'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the fitted XGBoost regressor on df_embeddings_test; y_pred is consumed\n",
    "# by the RMSE computation in the next cell.\n",
    "y_pred = xgb_predict.predict(df_embeddings_test)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RMSE Score: 0.6485\n"
     ]
    }
   ],
   "source": [
    "# Root-mean-squared error between the 'y' column of df_test_normalized\n",
    "# and the XGBoost predictions, reported to four decimal places.\n",
    "mse = mean_squared_error(df_test_normalized[\"y\"], y_pred)\n",
    "rmse = np.sqrt(mse)\n",
    "print(f\"RMSE Score: {rmse:.4f}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}