{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "4adab348", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:15.774148Z", "iopub.status.busy": "2024-07-25T10:02:15.773596Z", "iopub.status.idle": "2024-07-25T10:02:26.820147Z", "shell.execute_reply": "2024-07-25T10:02:26.818935Z" }, "executionInfo": { "elapsed": 556, "status": "ok", "timestamp": 1697339009466, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "COkMuAOy2J5o", "papermill": { "duration": 11.059195, "end_time": "2024-07-25T10:02:26.822747", "exception": false, "start_time": "2024-07-25T10:02:15.763552", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Detected operating system as Ubuntu/focal.\r\n", "Checking for curl...\r\n", "Detected curl...\r\n", "Checking for gpg...\r\n", "Detected gpg...\r\n", "Detected apt version as 2.0.10\r\n", "Running apt-get update... " ] }, { "name": "stdout", "output_type": "stream", "text": [ "done.\r\n", "Installing apt-transport-https... " ] }, { "name": "stdout", "output_type": "stream", "text": [ "done.\r\n", "Installing /etc/apt/sources.list.d/github_git-lfs.list..." ] }, { "name": "stdout", "output_type": "stream", "text": [ "done.\r\n", "Importing packagecloud gpg key... " ] }, { "name": "stdout", "output_type": "stream", "text": [ "Packagecloud gpg key imported to /etc/apt/keyrings/github_git-lfs-archive-keyring.gpg\r\n", "done.\r\n", "Running apt-get update... " ] }, { "name": "stdout", "output_type": "stream", "text": [ "done.\r\n", "\r\n", "The repository is setup! You can now install packages.\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 0%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 0%\r", "\r", "Reading package lists... 0%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 
3%\r", "\r", "Reading package lists... 3%\r", "\r", "Reading package lists... 3%\r", "\r", "Reading package lists... 3%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 35%\r", "\r", "Reading package lists... 35%\r", "\r", "Reading package lists... 36%\r", "\r", "Reading package lists... 36%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 46%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 48%\r", "\r", "Reading package lists... 48%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 60%\r", "\r", "Reading package lists... 60%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 64%\r", "\r", "Reading package lists... 64%\r", "\r", "Reading package lists... 64%\r", "\r", "Reading package lists... 64%\r", "\r", "Reading package lists... 65%\r", "\r", "Reading package lists... 65%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 65%\r", "\r", "Reading package lists... 65%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 76%\r", "\r", "Reading package lists... 76%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 87%\r", "\r", "Reading package lists... 87%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 91%\r", "\r", "Reading package lists... 91%\r", "\r", "Reading package lists... 91%\r", "\r", "Reading package lists... 91%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 96%\r", "\r", "Reading package lists... 96%\r", "\r", "Reading package lists... 96%\r", "\r", "Reading package lists... 96%\r", "\r", "Reading package lists... 96%\r", "\r", "Reading package lists... 96%\r", "\r", "Reading package lists... 
96%\r", "\r", "Reading package lists... 96%\r", "\r", "Reading package lists... 98%\r", "\r", "Reading package lists... 98%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... 99%\r", "\r", "Reading package lists... Done\r", "\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Building dependency tree... 0%\r", "\r", "Building dependency tree... 0%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Building dependency tree... 50%\r", "\r", "Building dependency tree... 50%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Building dependency tree... 73%\r" ] }, { "name": "stdout", "output_type": "stream", "text": [ "\r", "Building dependency tree \r", "\r\n", "\r", "Reading state information... 0%\r", "\r", "Reading state information... 0%\r", "\r", "Reading state information... 
Done\r", "\r\n", "git-lfs is already the newest version (3.5.1).\r\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "0 upgraded, 0 newly installed, 0 to remove and 88 not upgraded.\r\n" ] } ], "source": [ "!curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.deb.sh | sudo bash\n", "!apt-get install git-lfs --upgrade" ] }, { "cell_type": "code", "execution_count": 2, "id": "56c68fad", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:26.845725Z", "iopub.status.busy": "2024-07-25T10:02:26.845375Z", "iopub.status.idle": "2024-07-25T10:02:26.853194Z", "shell.execute_reply": "2024-07-25T10:02:26.852360Z" }, "papermill": { "duration": 0.021649, "end_time": "2024-07-25T10:02:26.855275", "exception": false, "start_time": "2024-07-25T10:02:26.833626", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "'\\n%cd /kaggle/working\\n!rm -rf ml-utility-loss\\n!git clone https://github.com/R-N/ml-utility-loss\\n%cd ml-utility-loss\\n!git pull\\n!rm setup.py\\n!curl -Lo setup.py https://github.com/R-N/ml-utility-loss/raw/main/setup.py\\n!pip install .\\n#!pip install . --no-deps --force-reinstall --upgrade\\n#'" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "\"\"\"\n", "%cd /kaggle/working\n", "!rm -rf ml-utility-loss\n", "!git clone https://github.com/R-N/ml-utility-loss\n", "%cd ml-utility-loss\n", "!git pull\n", "!rm setup.py\n", "!curl -Lo setup.py https://github.com/R-N/ml-utility-loss/raw/main/setup.py\n", "!pip install .\n", "#!pip install . 
--no-deps --force-reinstall --upgrade\n", "#\"\"\"" ] }, { "cell_type": "code", "execution_count": null, "id": "c6629a73", "metadata": { "papermill": { "duration": 0.009695, "end_time": "2024-07-25T10:02:26.874805", "exception": false, "start_time": "2024-07-25T10:02:26.865110", "status": "completed" }, "tags": [] }, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 3, "id": "0ad13705", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:26.896431Z", "iopub.status.busy": "2024-07-25T10:02:26.896129Z", "iopub.status.idle": "2024-07-25T10:02:26.903718Z", "shell.execute_reply": "2024-07-25T10:02:26.902748Z" }, "papermill": { "duration": 0.020798, "end_time": "2024-07-25T10:02:26.905598", "exception": false, "start_time": "2024-07-25T10:02:26.884800", "status": "completed" }, "tags": [ "parameters" ] }, "outputs": [], "source": [ "import os\n", "\n", "datasets = [\n", " \"insurance\",\n", " \"treatment\",\n", " \"contraceptive\"\n", "]\n", "models = [\"tvae\"]\n", "single_model = \"tvae\"\n", "\n", "\n", "model_dir = \".\"\n", "model_dir_2 = None\n", "study_dir = \"./\"\n", "\n", "path_prefix = \"../../../../\"\n", "\n", "dataset_dir=os.path.join(path_prefix, \"ml-utility-loss\", \"datasets\")\n", "dataset_name = \"contraceptive\"\n", "\n", "direction = \"maximize\"\n", "model_name = \"tvae_mlu\"\n", "\n", "mlu_model_dir=os.path.join(path_prefix, \"final\")\n", "mlu_model_name = \"tvae\"\n", "mlu_run = True\n", "\n", "gp = True\n", "gp_multiply = True\n", "\n", "df_name = \"df\"\n", "\n", "folder = \"eval\"\n", "path = None\n", "debug = False\n", "\n", "param_index = 0\n", "repo_index = 5\n", "use_all_data = False\n", "\n", "epoch_scale = 1\n", "save_model = True" ] }, { "cell_type": "code", "execution_count": 4, "id": "74336c5a", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:26.926685Z", "iopub.status.busy": "2024-07-25T10:02:26.926428Z", "iopub.status.idle": "2024-07-25T10:02:26.931486Z", 
"shell.execute_reply": "2024-07-25T10:02:26.930616Z" }, "papermill": { "duration": 0.01795, "end_time": "2024-07-25T10:02:26.933423", "exception": false, "start_time": "2024-07-25T10:02:26.915473", "status": "completed" }, "tags": [ "injected-parameters" ] }, "outputs": [], "source": [ "# Parameters\n", "dataset = \"iris\"\n", "dataset_name = \"iris\"\n", "gp = False\n", "gp_multiply = False\n", "df_name = 1\n", "folder = \"eval\"\n", "path_prefix = \"../../../../\"\n", "path = \"eval/iris/tvae/1\"\n", "model_dir = \".\"\n", "model_dir_2 = \".\"\n", "param_index = 0\n", "use_all_data = False\n", "repo_index = 5\n", "save_model = True\n" ] }, { "cell_type": "code", "execution_count": 5, "id": "b3a74a5a", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:26.955784Z", "iopub.status.busy": "2024-07-25T10:02:26.955507Z", "iopub.status.idle": "2024-07-25T10:02:26.961523Z", "shell.execute_reply": "2024-07-25T10:02:26.960707Z" }, "papermill": { "duration": 0.019274, "end_time": "2024-07-25T10:02:26.963567", "exception": false, "start_time": "2024-07-25T10:02:26.944293", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "mlu_est_eval_5_nogp_iris\n" ] } ], "source": [ "repo = \"mlu_est_eval\"\n", "if repo_index:\n", " repo = f\"{repo}_{repo_index}\"\n", "if gp:\n", " if gp_multiply:\n", " repo = f\"{repo}_gp_mul\"\n", " else:\n", " repo = f\"{repo}_gp_nomul\"\n", "else:\n", " repo = f\"{repo}_nogp\"\n", "repo = f\"{repo}_{dataset_name}\"\n", "print(repo)" ] }, { "cell_type": "code", "execution_count": 6, "id": "c8e1a018", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:26.985233Z", "iopub.status.busy": "2024-07-25T10:02:26.984952Z", "iopub.status.idle": "2024-07-25T10:02:27.988087Z", "shell.execute_reply": "2024-07-25T10:02:27.986982Z" }, "papermill": { "duration": 1.016723, "end_time": "2024-07-25T10:02:27.990691", "exception": false, "start_time": 
"2024-07-25T10:02:26.973968", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/kaggle/working\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "fatal: destination path 'final' already exists and is not an empty directory.\r\n" ] } ], "source": [ "#\"\"\"\n", "%cd /kaggle/working\n", "#!rm -rf final\n", "!git clone https://huggingface.co/linearch/{repo} final\n", "#\"\"\"" ] }, { "cell_type": "code", "execution_count": 7, "id": "7b1f4abe", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:28.013717Z", "iopub.status.busy": "2024-07-25T10:02:28.013389Z", "iopub.status.idle": "2024-07-25T10:02:29.029116Z", "shell.execute_reply": "2024-07-25T10:02:29.027822Z" }, "papermill": { "duration": 1.029845, "end_time": "2024-07-25T10:02:29.031314", "exception": false, "start_time": "2024-07-25T10:02:28.001469", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/kaggle/working\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "/kaggle/working/iris\n" ] } ], "source": [ "%cd /kaggle/working/\n", "#!rm -rf {dataset_name}\n", "!mkdir {dataset_name}\n", "%cd {dataset_name}" ] }, { "cell_type": "code", "execution_count": 8, "id": "1df581ed", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:29.055190Z", "iopub.status.busy": "2024-07-25T10:02:29.054408Z", "iopub.status.idle": "2024-07-25T10:02:29.064422Z", "shell.execute_reply": "2024-07-25T10:02:29.063524Z" }, "papermill": { "duration": 0.024359, "end_time": "2024-07-25T10:02:29.066488", "exception": false, "start_time": "2024-07-25T10:02:29.042129", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/kaggle/working\n", "/kaggle/working/eval/iris/tvae/1\n" ] } ], "source": [ "from pathlib import Path\n", "import os\n", "\n", "%cd /kaggle/working/\n", "\n", "if path is None:\n", " path = 
os.path.join(folder, dataset_name, single_model, str(df_name))  # df_name may be injected as an int\n", "Path(path).mkdir(parents=True, exist_ok=True)\n", "\n", "%cd {path}" ] }, { "cell_type": "code", "execution_count": 9, "id": "7893e955", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:29.090909Z", "iopub.status.busy": "2024-07-25T10:02:29.090611Z", "iopub.status.idle": "2024-07-25T10:02:29.460948Z", "shell.execute_reply": "2024-07-25T10:02:29.459829Z" }, "executionInfo": { "elapsed": 573, "status": "ok", "timestamp": 1697340246093, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "UdvXYv3c3LXy", "papermill": { "duration": 0.385315, "end_time": "2024-07-25T10:02:29.463524", "exception": false, "start_time": "2024-07-25T10:02:29.078209", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import json\n", "import os\n", "\n", "df = pd.read_csv(os.path.join(dataset_dir, f\"{dataset_name}.csv\"))\n", "with open(os.path.join(dataset_dir, f\"{dataset_name}.json\")) as f:\n", " info = json.load(f)" ] }, { "cell_type": "code", "execution_count": 10, "id": "f038419d", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:29.487027Z", "iopub.status.busy": "2024-07-25T10:02:29.486672Z", "iopub.status.idle": "2024-07-25T10:02:31.140841Z", "shell.execute_reply": "2024-07-25T10:02:31.140047Z" }, "executionInfo": { "elapsed": 17, "status": "ok", "timestamp": 1697340246097, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "Vrl2QkoV3o_8", "papermill": { "duration": 1.668648, "end_time": "2024-07-25T10:02:31.143213", "exception": false, "start_time": "2024-07-25T10:02:29.474565", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "from ml_utility_loss.util import split_df_ratio\n", "\n", "task = info[\"task\"]\n", "target = info[\"target\"]\n", "cat_features = info[\"cat_features\"]\n", "mixed_features = 
info[\"mixed_features\"]\n", "longtail_features = info[\"longtail_features\"]\n", "integer_features = info[\"integer_features\"]\n", "\n", "dfs = {\n", " \"df\": df,\n", "}\n", "dfs_test = {}\n", "for i in range(5):\n", " train, test = split_df_ratio(df, ratio=0.2, i=i, seed=42)\n", " dfs[i] = train\n", " dfs_test[i] = test\n" ] }, { "cell_type": "code", "execution_count": 11, "id": "452b822e", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:31.169302Z", "iopub.status.busy": "2024-07-25T10:02:31.168320Z", "iopub.status.idle": "2024-07-25T10:02:31.176458Z", "shell.execute_reply": "2024-07-25T10:02:31.175677Z" }, "executionInfo": { "elapsed": 365, "status": "ok", "timestamp": 1697343112569, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "a-SjylvlYl7i", "papermill": { "duration": 0.023424, "end_time": "2024-07-25T10:02:31.178478", "exception": false, "start_time": "2024-07-25T10:02:31.155054", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "from ml_utility_loss.util import mkdir, seed\n", "\n", "if isinstance(df_name, int) or df_name.isdigit():\n", " seed(int(df_name))\n", "else:\n", " seed(0)\n", "#model_name_2 = f\"{model_name}_{dataset_name}_{df_name}\"\n", "if model_dir_2 is None:\n", " model_dir_2 = os.path.join(model_dir, model_name, dataset_name, str(df_name))\n", "mkdir(model_dir_2)\n", "model_path = os.path.join(model_dir_2, f\"model.pt\")\n", "state_path = os.path.join(model_dir_2, f\"state.json\")\n", "params_path = os.path.join(model_dir_2, f\"params.json\")" ] }, { "cell_type": "code", "execution_count": 12, "id": "660b1b9c", "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "execution": { "iopub.execute_input": "2024-07-25T10:02:31.201571Z", "iopub.status.busy": "2024-07-25T10:02:31.201284Z", "iopub.status.idle": "2024-07-25T10:02:31.205819Z", "shell.execute_reply": "2024-07-25T10:02:31.204961Z" }, "executionInfo": { "elapsed": 4, "status": "ok", 
"timestamp": 1697343113004, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "-YDZUU9QUCwx", "outputId": "3d0fe03c-d2ee-4955-9589-0be8e213faf1", "papermill": { "duration": 0.018257, "end_time": "2024-07-25T10:02:31.207729", "exception": false, "start_time": "2024-07-25T10:02:31.189472", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "./model.pt\n" ] } ], "source": [ "print(model_path)" ] }, { "cell_type": "code", "execution_count": 13, "id": "8b43d834", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:31.230860Z", "iopub.status.busy": "2024-07-25T10:02:31.230570Z", "iopub.status.idle": "2024-07-25T10:02:31.348607Z", "shell.execute_reply": "2024-07-25T10:02:31.347319Z" }, "papermill": { "duration": 0.132637, "end_time": "2024-07-25T10:02:31.351348", "exception": false, "start_time": "2024-07-25T10:02:31.218711", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import ml_utility_loss.synthesizers.tab_ddpm.params as TAB_DDPM_PARAMS\n", "import ml_utility_loss.synthesizers.lct_gan.params as LCT_GAN_PARAMS\n", "import ml_utility_loss.synthesizers.realtabformer.params as RTF_PARAMS\n", "from ml_utility_loss.synthesizers.realtabformer.params.default import GPT2_PARAMS, REALTABFORMER_PARAMS\n", "from ml_utility_loss.util import filter_dict_2, filter_dict\n", "\n", "tab_ddpm_params = getattr(TAB_DDPM_PARAMS, dataset_name).BEST\n", "lct_gan_params = getattr(LCT_GAN_PARAMS, dataset_name).BEST\n", "lct_ae_params = filter_dict_2(lct_gan_params, LCT_GAN_PARAMS.default.AE_PARAMS)\n", "rtf_params = getattr(RTF_PARAMS, dataset_name).BEST\n", "rtf_params = filter_dict(rtf_params, REALTABFORMER_PARAMS)\n", "\n", "lct_ae_embedding_size=lct_gan_params[\"embedding_size\"]\n", "tab_ddpm_normalization=\"quantile\"\n", "tab_ddpm_cat_encoding=tab_ddpm_params[\"cat_encoding\"]\n", "#tab_ddpm_cat_encoding=\"one-hot\"\n", 
"tab_ddpm_y_policy=\"default\"\n", "tab_ddpm_is_y_cond=True" ] }, { "cell_type": "code", "execution_count": 14, "id": "27ebb9df", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:31.377612Z", "iopub.status.busy": "2024-07-25T10:02:31.377296Z", "iopub.status.idle": "2024-07-25T10:02:35.788721Z", "shell.execute_reply": "2024-07-25T10:02:35.787657Z" }, "papermill": { "duration": 4.426749, "end_time": "2024-07-25T10:02:35.791028", "exception": false, "start_time": "2024-07-25T10:02:31.364279", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2024-07-25 10:02:33.149497: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", "2024-07-25 10:02:33.149554: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", "2024-07-25 10:02:33.151052: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "mlu_run True\n" ] } ], "source": [ "from ml_utility_loss.loss_learning.estimator.pipeline import load_lct_ae\n", "\n", "if isinstance(mlu_run, (int, str)) or mlu_run is True:\n", " print(\"mlu_run\", mlu_run)\n", "# lct_ae = load_lct_ae(\n", "# dataset_name=dataset_name,\n", "# model_dir=os.path.join(path_prefix, \"ml-utility-loss/models\"),\n", "# model_name=\"lct_ae\",\n", "# df_name=\"df\",\n", "# )\n", "lct_ae = None" ] }, { "cell_type": "code", "execution_count": 15, "id": "4c3a950b", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:35.815719Z", "iopub.status.busy": "2024-07-25T10:02:35.815118Z", "iopub.status.idle": 
"2024-07-25T10:02:35.822552Z", "shell.execute_reply": "2024-07-25T10:02:35.821805Z" }, "papermill": { "duration": 0.021794, "end_time": "2024-07-25T10:02:35.824493", "exception": false, "start_time": "2024-07-25T10:02:35.802699", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "from ml_utility_loss.loss_learning.estimator.pipeline import load_rtf_embed\n", "\n", "rtf_embed = None\n", "if isinstance(mlu_run, (int, str)) or mlu_run is True:\n", " rtf_embed = load_rtf_embed(\n", " dataset_name=dataset_name,\n", " model_dir=os.path.join(path_prefix, \"ml-utility-loss/models\"),\n", " model_name=\"realtabformer\",\n", " df_name=\"df\",\n", " ckpt_type=\"best-disc-model\"\n", " )" ] }, { "cell_type": "code", "execution_count": 16, "id": "e69f9844", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:35.850460Z", "iopub.status.busy": "2024-07-25T10:02:35.850166Z", "iopub.status.idle": "2024-07-25T10:02:38.616751Z", "shell.execute_reply": "2024-07-25T10:02:38.615928Z" }, "papermill": { "duration": 2.782591, "end_time": "2024-07-25T10:02:38.619327", "exception": false, "start_time": "2024-07-25T10:02:35.836736", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", " warnings.warn(\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. 
Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", " warnings.warn(\n" ] } ], "source": [ "from ml_utility_loss.loss_learning.estimator.preprocessing import DataPreprocessor\n", "\n", "preprocessor = DataPreprocessor(\n", " task,\n", " target=target,\n", " cat_features=cat_features,\n", " mixed_features=mixed_features,\n", " longtail_features=longtail_features,\n", " integer_features=integer_features,\n", " lct_ae_embedding_size=lct_ae_embedding_size,\n", " lct_ae_params=lct_ae_params,\n", " lct_ae=lct_ae,\n", " tab_ddpm_normalization=tab_ddpm_normalization,\n", " tab_ddpm_cat_encoding=tab_ddpm_cat_encoding,\n", " tab_ddpm_y_policy=tab_ddpm_y_policy,\n", " tab_ddpm_is_y_cond=tab_ddpm_is_y_cond,\n", " realtabformer_embedding=rtf_embed,\n", " realtabformer_params=rtf_params,\n", ")\n", "preprocessor.fit(df)" ] }, { "cell_type": "code", "execution_count": 17, "id": "0f77640d", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:38.646712Z", "iopub.status.busy": "2024-07-25T10:02:38.646004Z", "iopub.status.idle": "2024-07-25T10:02:38.653417Z", "shell.execute_reply": "2024-07-25T10:02:38.652338Z" }, "papermill": { "duration": 0.023888, "end_time": "2024-07-25T10:02:38.655944", "exception": false, "start_time": "2024-07-25T10:02:38.632056", "status": "completed" }, "tags": [] }, "outputs": [ { "data": { "text/plain": [ "{'tvae': 24,\n", " 'realtabformer': (31, 89, Embedding(89, 864), True),\n", " 'lct_gan': 14,\n", " 'tab_ddpm_concat': 5}" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "preprocessor.adapter_sizes" ] }, { "cell_type": "code", "execution_count": 18, "id": "be5be29a", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:38.681579Z", "iopub.status.busy": "2024-07-25T10:02:38.681315Z", "iopub.status.idle": "2024-07-25T10:02:39.199181Z", "shell.execute_reply": "2024-07-25T10:02:39.198095Z" }, "papermill": { "duration": 0.533117, "end_time": 
"2024-07-25T10:02:39.201704", "exception": false, "start_time": "2024-07-25T10:02:38.668587", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "from ml_utility_loss.loss_learning.estimator.pipeline import create_model, ISABMode, LoRAMode\n", "import ml_utility_loss.loss_learning.estimator.params2 as PARAMS\n", "from ml_utility_loss.tuning import map_parameters\n", "from ml_utility_loss.util import clear_memory\n", "import torch\n", "import json\n", "\n", "clear_memory()\n", "if isinstance(mlu_run, (int, str)) or mlu_run is True:\n", " param_space = {\n", " **getattr(PARAMS, dataset_name).PARAM_SPACE,\n", " #**getattr(PARAMS, dataset_name).PARAM_SPACE_2\n", " }\n", " params = getattr(PARAMS, dataset_name).BEST_DICT[gp][gp_multiply][single_model]\n", " #params = PARAMS.default.update_params_2(params, info[\"sizes\"])\n", " params[\"single_model\"] = False\n", " if models:\n", " params[\"models\"] = models\n", " if single_model:\n", " params[\"fixed_role_model\"] = single_model\n", " params[\"single_model\"] = True\n", " params[\"models\"] = [single_model]\n", " if gp:\n", " params[\"gradient_penalty_mode\"] = \"ALL\"\n", " params[\"mse_mag\"] = True\n", " if gp_multiply:\n", " params[\"mse_mag_multiply\"] = True\n", " #params[\"mse_mag_target\"] = 1.0\n", " else:\n", " params[\"mse_mag_multiply\"] = False\n", " #params[\"mse_mag_target\"] = 0.1\n", " else:\n", " params[\"gradient_penalty_mode\"] = \"NONE\"\n", " params[\"mse_mag\"] = False\n", " with open(\"params.json\", \"w\") as f:\n", " json.dump(params, f)\n", " params = map_parameters(params, param_space=param_space)\n", " params" ] }, { "cell_type": "code", "execution_count": 19, "id": "7bc7b4bd", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:39.229060Z", "iopub.status.busy": "2024-07-25T10:02:39.228633Z", "iopub.status.idle": "2024-07-25T10:02:39.248717Z", "shell.execute_reply": "2024-07-25T10:02:39.247788Z" }, "papermill": { "duration": 0.036244, "end_time": 
"2024-07-25T10:02:39.250847", "exception": false, "start_time": "2024-07-25T10:02:39.214603", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Creating model of type \n", "[*] Embedding False True\n", "['tvae'] 1\n" ] } ], "source": [ "from ml_utility_loss.loss_learning.estimator.model.pipeline import remove_non_model_params\n", "\n", "if isinstance(mlu_run, (int, str)) or mlu_run is True:\n", " params2 = remove_non_model_params(params)\n", " mlu_model = create_model(\n", " adapters=preprocessor.adapter_sizes,\n", " #Body=\"twin_encoder\",\n", " **params2,\n", " )\n", " #cf.apply_weight_standardization(model, n_last_layers_ignore=0)\n", " print(mlu_model.models, len(mlu_model.adapters))\n", "else:\n", " mlu_model = None\n", " mlu_trainer = None\n", " ae_mlu_trainer = None\n", " gan_mlu_trainer = None" ] }, { "cell_type": "code", "execution_count": 20, "id": "c8a7d7c4", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:39.277039Z", "iopub.status.busy": "2024-07-25T10:02:39.276688Z", "iopub.status.idle": "2024-07-25T10:02:39.325843Z", "shell.execute_reply": "2024-07-25T10:02:39.324824Z" }, "papermill": { "duration": 0.06449, "end_time": "2024-07-25T10:02:39.327867", "exception": false, "start_time": "2024-07-25T10:02:39.263377", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "{'t_start': 0, 't_end': 904, 't_range': None, 'n_steps': 1, 'n_inner_steps': 1, 'n_inner_steps_2': 1, 'loss_mul': 1, 'div_batch': True, 'forgive_over': True, 'n_real': 0, 'n_samples': 64, 't_steps': 16, 'mlu_run': 3, 'target': None, 'lr': 0.004747413139187287, 'loss_fn': , 'Optim': functools.partial(, amsgrad=True)}\n" ] } ], "source": [ "import ml_utility_loss.synthesizers.tvae.params2 as MLU_PARAMS\n", "from ml_utility_loss.tuning import map_parameters\n", "if isinstance(mlu_run, (int, str)) or mlu_run is True:\n", " mlu_params0 = getattr(MLU_PARAMS, 
dataset_name)\n", " mlu_params = mlu_params0.BEST_DICT[gp][gp_multiply]\n", " if isinstance(mlu_params, (list, tuple)):\n", " mlu_params = mlu_params[param_index]\n", " mlu_params = map_parameters(mlu_params, param_space=mlu_params0.PARAM_SPACE)\n", " mlu_params[\"target\"] = mlu_params.pop(\"mlu_target\", mlu_params.pop(\"target\", None))\n", " mlu_params[\"lr\"] = mlu_params.pop(\"mlu_lr\", mlu_params.pop(\"lr\", None))\n", " mlu_params[\"loss_fn\"] = mlu_params.pop(\"mlu_loss_fn\", mlu_params.pop(\"loss_fn\", None))\n", " mlu_params[\"Optim\"] = mlu_params.pop(\"mlu_Optim\", mlu_params.pop(\"Optim\", None))\n", " mlu_params.pop(\"gradient_penalty_kwargs\", None)\n", " print(mlu_params)" ] }, { "cell_type": "code", "execution_count": 21, "id": "4e7b8902", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:39.356684Z", "iopub.status.busy": "2024-07-25T10:02:39.355803Z", "iopub.status.idle": "2024-07-25T10:02:39.375410Z", "shell.execute_reply": "2024-07-25T10:02:39.374657Z" }, "papermill": { "duration": 0.036076, "end_time": "2024-07-25T10:02:39.377441", "exception": false, "start_time": "2024-07-25T10:02:39.341365", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import torch\n", "\n", "if isinstance(mlu_run, (int, str)) or mlu_run is True:\n", " if mlu_run is True:\n", " mlu_run = mlu_params[\"mlu_run\"]\n", " mlu_params.pop(\"mlu_run\", None)\n", " mlu_model_dir_2 = os.path.join(mlu_model_dir, dataset_name, mlu_model_name, str(mlu_run))\n", " mlu_model_path = os.path.join(mlu_model_dir_2, f\"model.pt\")\n", "\n", " mlu_model.load_state_dict(torch.load(mlu_model_path))" ] }, { "cell_type": "code", "execution_count": 22, "id": "bdc96d0e", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:39.403618Z", "iopub.status.busy": "2024-07-25T10:02:39.403285Z", "iopub.status.idle": "2024-07-25T10:02:39.407319Z", "shell.execute_reply": "2024-07-25T10:02:39.406563Z" }, "papermill": { "duration": 0.019665, "end_time": 
"2024-07-25T10:02:39.409565", "exception": false, "start_time": "2024-07-25T10:02:39.389900", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "df_name_2 = \"df\" if use_all_data else df_name" ] }, { "cell_type": "code", "execution_count": 23, "id": "05302fea", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:39.435599Z", "iopub.status.busy": "2024-07-25T10:02:39.435282Z", "iopub.status.idle": "2024-07-25T10:02:39.828608Z", "shell.execute_reply": "2024-07-25T10:02:39.827727Z" }, "papermill": { "duration": 0.409336, "end_time": "2024-07-25T10:02:39.831086", "exception": false, "start_time": "2024-07-25T10:02:39.421750", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Caching in ../../../../iris/_cache_inference/tvae/all inf False\n" ] } ], "source": [ "from ml_utility_loss.loss_learning.estimator.pipeline import load_dataset\n", "\n", "if isinstance(mlu_run, (int, str)) or mlu_run is True:\n", " dataset = load_dataset(\n", " dataset_dir=os.path.join(path_prefix, \"ml-utility-loss/\", \"synthetics\", dataset_name),\n", " preprocessor=preprocessor,\n", " cache_dir=os.path.join(path_prefix, dataset_name, \"_cache_inference\"),\n", " val=False,\n", " ratio=None,\n", " drop_first_column=True,\n", " model=single_model,\n", " train=\"train\", test=\"test\", value=\"real_value\",\n", " file=\"info_2.csv\",\n", " )" ] }, { "cell_type": "code", "execution_count": 24, "id": "98d1e706", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:39.860960Z", "iopub.status.busy": "2024-07-25T10:02:39.859861Z", "iopub.status.idle": "2024-07-25T10:02:39.869886Z", "shell.execute_reply": "2024-07-25T10:02:39.868977Z" }, "papermill": { "duration": 0.027367, "end_time": "2024-07-25T10:02:39.871900", "exception": false, "start_time": "2024-07-25T10:02:39.844533", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "mlu step 
every 16 starting 0 until 904\n", "mlu step times 1*1*1\n", "mlu samples 64 / 120\n", "mlu logging ./mlu_log.csv\n" ] } ], "source": [ "from ml_utility_loss.loss_learning.estimator.wrapper import MLUtilityTrainer\n", "\n", "if isinstance(mlu_run, (int, str)) or mlu_run is True:\n", " mlu_trainer = MLUtilityTrainer(\n", " model=mlu_model[single_model],\n", " dataset=dataset,\n", " debug=True,\n", " log_path=os.path.join(model_dir_2, \"mlu_log.csv\"),\n", " **mlu_params,\n", " )" ] }, { "cell_type": "code", "execution_count": 25, "id": "13ec9d9e", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:39.898068Z", "iopub.status.busy": "2024-07-25T10:02:39.897765Z", "iopub.status.idle": "2024-07-25T10:02:39.905644Z", "shell.execute_reply": "2024-07-25T10:02:39.904966Z" }, "executionInfo": { "elapsed": 3, "status": "ok", "timestamp": 1697343113004, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "NgahtU1q9uLO", "papermill": { "duration": 0.023213, "end_time": "2024-07-25T10:02:39.907561", "exception": false, "start_time": "2024-07-25T10:02:39.884348", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "\n", "from ml_utility_loss.tuning import create_objective\n", "import ml_utility_loss.synthesizers.tvae.params as PARAMS\n", "from ml_utility_loss.util import filter_dict_2\n", "\n", "params = getattr(PARAMS, dataset_name).BEST\n", "\n", "model_params={\n", " **params,\n", "}\n", "\n", "for x in [\"compress\", \"decompress\"]:\n", " model_params[f\"{x}_dims\"] = [\n", " model_params[f\"{x}_dims\"]\n", " for i in range(\n", " model_params.pop(f\"{x}_depth\")\n", " )\n", " ]\n", "\n", "model_params[\"epochs\"] = int(round(epoch_scale * model_params[\"epochs\"]))\n", "\n", "model_params[\"mlu_trainer\"] = mlu_trainer" ] }, { "cell_type": "code", "execution_count": 26, "id": "3ab8966e", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:02:39.934278Z", "iopub.status.busy": 
"2024-07-25T10:02:39.934006Z", "iopub.status.idle": "2024-07-25T10:03:52.779270Z", "shell.execute_reply": "2024-07-25T10:03:52.778268Z" }, "executionInfo": { "elapsed": 286822, "status": "ok", "timestamp": 1697343399823, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "wGsDQTlk8Zl1", "papermill": { "duration": 72.861152, "end_time": "2024-07-25T10:03:52.781504", "exception": false, "start_time": "2024-07-25T10:02:39.920352", "status": "completed" }, "tags": [] }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.012798070907592773\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.006090521812438965\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.004310190677642822\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.004614889621734619\n" ] }, { "name": 
"stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0051451921463012695\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0048449039459228516\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0075307488441467285\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.006430864334106445\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.007298827171325684\n" ] 
}, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.006350696086883545\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0059430599212646484\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "MLU loss 0.0\n" ] }, { "data": { "text/plain": [ "(-16.4831600189209, 5.026618957519531)" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from ml_utility_loss.synthesizers.tvae.wrapper import TVAE\n", "from ml_utility_loss.synthesizers.tvae.process import preprocess\n", "import inspect\n", "\n", "train = dfs[df_name_2]\n", "#transformer, *_ = preprocess(df, cat_features)\n", "tvae = TVAE(**model_params)\n", "#print(inspect.getargspec(tvae.fit))\n", "tvae.fit(train, cat_features, preprocess_df=df)" ] }, { "cell_type": "code", "execution_count": 27, "id": "e609d51f", "metadata": { "execution": { "iopub.execute_input": "2024-07-25T10:03:52.819871Z", "iopub.status.busy": "2024-07-25T10:03:52.819509Z", "iopub.status.idle": "2024-07-25T10:03:52.848683Z", "shell.execute_reply": "2024-07-25T10:03:52.847933Z" }, "executionInfo": { "elapsed": 403, "status": "ok", "timestamp": 1697343400203, "user": { "displayName": "Rizqi Nur", 
"userId": "09644007964068789560" }, "user_tz": -420 }, "id": "qUHHYJNRJdDy", "papermill": { "duration": 0.050969, "end_time": "2024-07-25T10:03:52.850738", "exception": false, "start_time": "2024-07-25T10:03:52.799769", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "import torch\n", "from copy import deepcopy\n", "import json\n", "\n", "if save_model:\n", " tvae.mlu_trainer = None\n", " if True or not os.path.exists(model_path):\n", " torch.save(tvae, model_path)\n", " if True or not os.path.exists(state_path):\n", " torch.save(deepcopy(tvae.model.state_dict()), state_path)\n", "model_params.pop(\"mlu_trainer\", None)\n", "if True or not os.path.exists(params_path):\n", " with open(params_path, \"w\") as f:\n", " json.dump(model_params, f, indent=4)" ] }, { "cell_type": "code", "execution_count": 28, "id": "40785600", "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 228 }, "execution": { "iopub.execute_input": "2024-07-25T10:03:52.887743Z", "iopub.status.busy": "2024-07-25T10:03:52.887422Z", "iopub.status.idle": "2024-07-25T10:03:53.738766Z", "shell.execute_reply": "2024-07-25T10:03:53.737977Z" }, "executionInfo": { "elapsed": 13, "status": "error", "timestamp": 1697343400203, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "7SQgzf5-P249", "outputId": "57519fd1-a4d2-4922-c834-1fe290391c16", "papermill": { "duration": 0.873221, "end_time": "2024-07-25T10:03:53.741234", "exception": false, "start_time": "2024-07-25T10:03:52.868013", "status": "completed" }, "tags": [] }, "outputs": [], "source": [ "for i in range(30):\n", " seed(i)\n", " synth_df = tvae.sample(len(train))\n", " model_dir_3 = os.path.join(model_dir_2, str(i))\n", " mkdir(model_dir_3)\n", " synth_path = os.path.join(model_dir_3, f\"synth.csv\")\n", " train_path = os.path.join(model_dir_3, f\"train.csv\")\n", " val_path = os.path.join(model_dir_3, f\"val.csv\")\n", " test_path = os.path.join(model_dir_3, 
f\"test.csv\")\n", "\n", " synth_df.to_csv(synth_path)\n", " dfs[df_name_2].to_csv(train_path)\n", " if df_name in dfs_test:\n", " dfs_test[df_name].to_csv(val_path)\n", " dfs_test[df_name].to_csv(test_path)" ] }, { "cell_type": "code", "execution_count": null, "id": "0e1264cb", "metadata": { "executionInfo": { "elapsed": 10, "status": "aborted", "timestamp": 1697343400204, "user": { "displayName": "Rizqi Nur", "userId": "09644007964068789560" }, "user_tz": -420 }, "id": "nB724X33v4Qd", "papermill": { "duration": 0.018263, "end_time": "2024-07-25T10:03:53.778643", "exception": false, "start_time": "2024-07-25T10:03:53.760380", "status": "completed" }, "tags": [] }, "outputs": [], "source": [] } ], "metadata": { "celltoolbar": "Tags", "colab": { "authorship_tag": "ABX9TyMhtJHOKAMnUJyVHa+D8Sml", "mount_file_id": "1Cug9laqjkt9fyDxiylSn9Jzam9kQyDu3", "provenance": [] }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" }, "papermill": { "default_parameters": {}, "duration": 102.090619, "end_time": "2024-07-25T10:03:56.404298", "environment_variables": {}, "exception": null, "input_path": "eval/iris/tvae/1/tvae_eval.ipynb", "output_path": "eval/iris/tvae/1/tvae_eval.ipynb", "parameters": { "dataset": "iris", "dataset_name": "iris", "df_name": 1, "folder": "eval", "gp": false, "gp_multiply": false, "model_dir": ".", "model_dir_2": ".", "param_index": 0, "path": "eval/iris/tvae/1", "path_prefix": "../../../../", "repo_index": 5, "save_model": true, "use_all_data": false }, "start_time": "2024-07-25T10:02:14.313679", "version": "2.5.0" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": false, "title_cell": "Table of Contents", "title_sidebar": 
"Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }