diff --git "a/treatment/lct_gan/mlu-eval.ipynb" "b/treatment/lct_gan/mlu-eval.ipynb" new file mode 100644--- /dev/null +++ "b/treatment/lct_gan/mlu-eval.ipynb" @@ -0,0 +1,2380 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "982e76f5", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T16:53:54.603965Z", + "iopub.status.busy": "2024-03-22T16:53:54.603594Z", + "iopub.status.idle": "2024-03-22T16:53:54.637506Z", + "shell.execute_reply": "2024-03-22T16:53:54.636558Z" + }, + "papermill": { + "duration": 0.049204, + "end_time": "2024-03-22T16:53:54.639737", + "exception": false, + "start_time": "2024-03-22T16:53:54.590533", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import joblib\n", + "\n", + "#joblib.parallel_backend(\"threading\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "675f0b41", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T16:53:54.666476Z", + "iopub.status.busy": "2024-03-22T16:53:54.666093Z", + "iopub.status.idle": "2024-03-22T16:53:54.673364Z", + "shell.execute_reply": "2024-03-22T16:53:54.672349Z" + }, + "papermill": { + "duration": 0.023236, + "end_time": "2024-03-22T16:53:54.675564", + "exception": false, + "start_time": "2024-03-22T16:53:54.652328", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'\\n%cd /kaggle/working\\n#!git clone https://github.com/R-N/ml-utility-loss --depth=1 --single-branch --branch=main\\n%cd ml-utility-loss\\n!git pull\\n#!pip install .\\n!pip install . --no-deps --force-reinstall --upgrade\\n#'" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "\"\"\"\n", + "%cd /kaggle/working\n", + "#!git clone https://github.com/R-N/ml-utility-loss --depth=1 --single-branch --branch=main\n", + "%cd ml-utility-loss\n", + "!git pull\n", + "#!pip install .\n", + "!pip install . 
--no-deps --force-reinstall --upgrade\n", + "#\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "5ae30f5c", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T16:53:54.700593Z", + "iopub.status.busy": "2024-03-22T16:53:54.700315Z", + "iopub.status.idle": "2024-03-22T16:53:54.704453Z", + "shell.execute_reply": "2024-03-22T16:53:54.703667Z" + }, + "papermill": { + "duration": 0.019169, + "end_time": "2024-03-22T16:53:54.706517", + "exception": false, + "start_time": "2024-03-22T16:53:54.687348", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt\n", + "plt.rcParams['figure.figsize'] = [3,3]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9f42c810", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T16:53:54.731010Z", + "iopub.status.busy": "2024-03-22T16:53:54.730726Z", + "iopub.status.idle": "2024-03-22T16:53:54.734735Z", + "shell.execute_reply": "2024-03-22T16:53:54.733853Z" + }, + "executionInfo": { + "elapsed": 678, + "status": "ok", + "timestamp": 1696841022168, + "user": { + "displayName": "Rizqi Nur", + "userId": "09644007964068789560" + }, + "user_tz": -420 + }, + "id": "ns5hFcVL2yvs", + "papermill": { + "duration": 0.018617, + "end_time": "2024-03-22T16:53:54.736705", + "exception": false, + "start_time": "2024-03-22T16:53:54.718088", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "datasets = [\n", + " \"insurance\",\n", + " \"treatment\",\n", + " \"contraceptive\"\n", + "]\n", + "\n", + "study_dir = \"./\"" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "85d0c8ce", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T16:53:54.760821Z", + "iopub.status.busy": "2024-03-22T16:53:54.760582Z", + "iopub.status.idle": "2024-03-22T16:53:54.766024Z", + "shell.execute_reply": "2024-03-22T16:53:54.765223Z" + }, + "papermill": { + "duration": 0.019584, + 
"end_time": "2024-03-22T16:53:54.768060", + "exception": false, + "start_time": "2024-03-22T16:53:54.748476", + "status": "completed" + }, + "tags": [ + "parameters" + ] + }, + "outputs": [], + "source": [ + "#Parameters\n", + "import os\n", + "\n", + "path_prefix = \"../../../../\"\n", + "\n", + "dataset_dir = os.path.join(path_prefix, \"ml-utility-loss/datasets\")\n", + "dataset_name = \"treatment\"\n", + "model_name=\"ml_utility_2\"\n", + "models = [\"tvae\", \"realtabformer\", \"lct_gan\", \"tab_ddpm_concat\"]\n", + "single_model = \"lct_gan\"\n", + "random_seed = 42\n", + "gp = True\n", + "gp_multiply = True\n", + "folder = \"eval\"\n", + "debug = False\n", + "path = None\n", + "param_index = 0\n", + "allow_same_prediction = True\n", + "log_wandb = False" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "fe93b2cc", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T16:53:54.793738Z", + "iopub.status.busy": "2024-03-22T16:53:54.793435Z", + "iopub.status.idle": "2024-03-22T16:53:54.798469Z", + "shell.execute_reply": "2024-03-22T16:53:54.797700Z" + }, + "papermill": { + "duration": 0.020228, + "end_time": "2024-03-22T16:53:54.800426", + "exception": false, + "start_time": "2024-03-22T16:53:54.780198", + "status": "completed" + }, + "tags": [ + "injected-parameters" + ] + }, + "outputs": [], + "source": [ + "# Parameters\n", + "dataset = \"treatment\"\n", + "dataset_name = \"treatment\"\n", + "single_model = \"lct_gan\"\n", + "gp = True\n", + "gp_multiply = True\n", + "random_seed = 42\n", + "debug = False\n", + "folder = \"eval\"\n", + "path_prefix = \"../../../../\"\n", + "path = \"eval/treatment/lct_gan/42\"\n", + "param_index = 0\n", + "allow_same_prediction = True\n", + "log_wandb = False\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bd7c02d6", + "metadata": { + "papermill": { + "duration": 0.011272, + "end_time": "2024-03-22T16:53:54.823078", + "exception": false, + "start_time": 
"2024-03-22T16:53:54.811806", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "5f45b1d0", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T16:53:54.847715Z", + "iopub.status.busy": "2024-03-22T16:53:54.847006Z", + "iopub.status.idle": "2024-03-22T16:53:54.856170Z", + "shell.execute_reply": "2024-03-22T16:53:54.855346Z" + }, + "executionInfo": { + "elapsed": 7, + "status": "ok", + "timestamp": 1696841022169, + "user": { + "displayName": "Rizqi Nur", + "userId": "09644007964068789560" + }, + "user_tz": -420 + }, + "id": "UdvXYv3c3LXy", + "papermill": { + "duration": 0.023787, + "end_time": "2024-03-22T16:53:54.858208", + "exception": false, + "start_time": "2024-03-22T16:53:54.834421", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "/kaggle/working\n", + "/kaggle/working/eval/treatment/lct_gan/42\n" + ] + } + ], + "source": [ + "from pathlib import Path\n", + "import os\n", + "\n", + "%cd /kaggle/working/\n", + "\n", + "if path is None:\n", + " path = os.path.join(folder, dataset_name, single_model, random_seed)\n", + "Path(path).mkdir(parents=True, exist_ok=True)\n", + "\n", + "%cd {path}" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "f85bf540", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T16:53:54.882551Z", + "iopub.status.busy": "2024-03-22T16:53:54.881844Z", + "iopub.status.idle": "2024-03-22T16:53:56.921479Z", + "shell.execute_reply": "2024-03-22T16:53:56.920577Z" + }, + "papermill": { + "duration": 2.054095, + "end_time": "2024-03-22T16:53:56.923643", + "exception": false, + "start_time": "2024-03-22T16:53:54.869548", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Set seed to \n" + ] + } + ], + "source": [ + "from ml_utility_loss.util import 
seed\n", + "if single_model:\n", + " model_name=f\"{model_name}_{single_model}\"\n", + "if random_seed is not None:\n", + " seed(random_seed)\n", + " print(\"Set seed to\", random_seed)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "8489feae", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T16:53:56.951621Z", + "iopub.status.busy": "2024-03-22T16:53:56.951153Z", + "iopub.status.idle": "2024-03-22T16:53:56.967827Z", + "shell.execute_reply": "2024-03-22T16:53:56.967069Z" + }, + "papermill": { + "duration": 0.033246, + "end_time": "2024-03-22T16:53:56.969900", + "exception": false, + "start_time": "2024-03-22T16:53:56.936654", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import json\n", + "import os\n", + "\n", + "df = pd.read_csv(os.path.join(dataset_dir, f\"{dataset_name}.csv\"))\n", + "with open(os.path.join(dataset_dir, f\"{dataset_name}.json\")) as f:\n", + " info = json.load(f)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "debcc684", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T16:53:56.994398Z", + "iopub.status.busy": "2024-03-22T16:53:56.994089Z", + "iopub.status.idle": "2024-03-22T16:53:57.002129Z", + "shell.execute_reply": "2024-03-22T16:53:57.001365Z" + }, + "executionInfo": { + "elapsed": 6, + "status": "ok", + "timestamp": 1696841022169, + "user": { + "displayName": "Rizqi Nur", + "userId": "09644007964068789560" + }, + "user_tz": -420 + }, + "id": "Vrl2QkoV3o_8", + "papermill": { + "duration": 0.022576, + "end_time": "2024-03-22T16:53:57.004084", + "exception": false, + "start_time": "2024-03-22T16:53:56.981508", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "task = info[\"task\"]\n", + "target = info[\"target\"]\n", + "cat_features = info[\"cat_features\"]\n", + "mixed_features = info[\"mixed_features\"]\n", + "longtail_features = 
info[\"longtail_features\"]\n", + "integer_features = info[\"integer_features\"]\n", + "\n", + "test = df.sample(frac=0.2, random_state=42)\n", + "train = df[~df.index.isin(test.index)]" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "7538184a", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T16:53:57.027681Z", + "iopub.status.busy": "2024-03-22T16:53:57.027415Z", + "iopub.status.idle": "2024-03-22T16:53:57.124714Z", + "shell.execute_reply": "2024-03-22T16:53:57.123657Z" + }, + "executionInfo": { + "elapsed": 6, + "status": "ok", + "timestamp": 1696841022169, + "user": { + "displayName": "Rizqi Nur", + "userId": "09644007964068789560" + }, + "user_tz": -420 + }, + "id": "TilUuFk9vqMb", + "papermill": { + "duration": 0.111825, + "end_time": "2024-03-22T16:53:57.127116", + "exception": false, + "start_time": "2024-03-22T16:53:57.015291", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import ml_utility_loss.synthesizers.tab_ddpm.params as TAB_DDPM_PARAMS\n", + "import ml_utility_loss.synthesizers.lct_gan.params as LCT_GAN_PARAMS\n", + "import ml_utility_loss.synthesizers.realtabformer.params as RTF_PARAMS\n", + "from ml_utility_loss.synthesizers.realtabformer.params.default import GPT2_PARAMS, REALTABFORMER_PARAMS\n", + "from ml_utility_loss.util import filter_dict_2, filter_dict\n", + "\n", + "tab_ddpm_params = getattr(TAB_DDPM_PARAMS, dataset_name).BEST\n", + "lct_gan_params = getattr(LCT_GAN_PARAMS, dataset_name).BEST\n", + "lct_ae_params = filter_dict_2(lct_gan_params, LCT_GAN_PARAMS.default.AE_PARAMS)\n", + "rtf_params = getattr(RTF_PARAMS, dataset_name).BEST\n", + "rtf_params = filter_dict(rtf_params, REALTABFORMER_PARAMS)\n", + "\n", + "lct_ae_embedding_size=lct_gan_params[\"embedding_size\"]\n", + "tab_ddpm_normalization=\"quantile\"\n", + "tab_ddpm_cat_encoding=tab_ddpm_params[\"cat_encoding\"]\n", + "#tab_ddpm_cat_encoding=\"one-hot\"\n", + "tab_ddpm_y_policy=\"default\"\n", + 
"tab_ddpm_is_y_cond=True" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "cca61838", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T16:53:57.155124Z", + "iopub.status.busy": "2024-03-22T16:53:57.154827Z", + "iopub.status.idle": "2024-03-22T16:54:01.862723Z", + "shell.execute_reply": "2024-03-22T16:54:01.861696Z" + }, + "executionInfo": { + "elapsed": 3113, + "status": "ok", + "timestamp": 1696841025277, + "user": { + "displayName": "Rizqi Nur", + "userId": "09644007964068789560" + }, + "user_tz": -420 + }, + "id": "7Abt8nStvr9Z", + "papermill": { + "duration": 4.725452, + "end_time": "2024-03-22T16:54:01.865848", + "exception": false, + "start_time": "2024-03-22T16:53:57.140396", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2024-03-22 16:53:59.415157: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", + "2024-03-22 16:53:59.415216: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", + "2024-03-22 16:53:59.416897: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n" + ] + } + ], + "source": [ + "from ml_utility_loss.loss_learning.estimator.pipeline import load_lct_ae\n", + "\n", + "lct_ae = load_lct_ae(\n", + " dataset_name=dataset_name,\n", + " model_dir=os.path.join(path_prefix, \"ml-utility-loss/models\"),\n", + " model_name=\"lct_ae\",\n", + " df_name=\"df\",\n", + ")\n", + "lct_ae = None" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "6f83b7b6", + "metadata": { + "execution": { + "iopub.execute_input": 
"2024-03-22T16:54:01.897085Z", + "iopub.status.busy": "2024-03-22T16:54:01.896382Z", + "iopub.status.idle": "2024-03-22T16:54:01.903562Z", + "shell.execute_reply": "2024-03-22T16:54:01.902506Z" + }, + "papermill": { + "duration": 0.022721, + "end_time": "2024-03-22T16:54:01.905669", + "exception": false, + "start_time": "2024-03-22T16:54:01.882948", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "from ml_utility_loss.loss_learning.estimator.pipeline import load_rtf_embed\n", + "\n", + "rtf_embed = load_rtf_embed(\n", + " dataset_name=dataset_name,\n", + " model_dir=os.path.join(path_prefix, \"ml-utility-loss/models\"),\n", + " model_name=\"realtabformer\",\n", + " df_name=\"df\",\n", + " ckpt_type=\"best-disc-model\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "0026de74", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T16:54:01.933894Z", + "iopub.status.busy": "2024-03-22T16:54:01.933539Z", + "iopub.status.idle": "2024-03-22T16:54:24.491400Z", + "shell.execute_reply": "2024-03-22T16:54:24.490297Z" + }, + "executionInfo": { + "elapsed": 20137, + "status": "ok", + "timestamp": 1696841045408, + "user": { + "displayName": "Rizqi Nur", + "userId": "09644007964068789560" + }, + "user_tz": -420 + }, + "id": "tbaguWxAvtPi", + "papermill": { + "duration": 22.574641, + "end_time": "2024-03-22T16:54:24.494065", + "exception": false, + "start_time": "2024-03-22T16:54:01.919424", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. 
Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. 
Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n", + " warnings.warn(\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.10/site-packages/sklearn/preprocessing/_encoders.py:868: FutureWarning: `sparse` was renamed to `sparse_output` in version 1.2 and will be removed in 1.4. `sparse_output` is ignored unless you leave `sparse` to its default value.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "from ml_utility_loss.loss_learning.estimator.preprocessing import DataPreprocessor\n", + "\n", + "preprocessor = DataPreprocessor(\n", + " task,\n", + " target=target,\n", + " cat_features=cat_features,\n", + " mixed_features=mixed_features,\n", + " longtail_features=longtail_features,\n", + " integer_features=integer_features,\n", + " lct_ae_embedding_size=lct_ae_embedding_size,\n", + " lct_ae_params=lct_ae_params,\n", + " lct_ae=lct_ae,\n", + " tab_ddpm_normalization=tab_ddpm_normalization,\n", + " tab_ddpm_cat_encoding=tab_ddpm_cat_encoding,\n", + " tab_ddpm_y_policy=tab_ddpm_y_policy,\n", + " tab_ddpm_is_y_cond=tab_ddpm_is_y_cond,\n", + " realtabformer_embedding=rtf_embed,\n", + " realtabformer_params=rtf_params,\n", + ")\n", + "preprocessor.fit(df)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "a9c9b110", + "metadata": { + "colab": { + 
"base_uri": "https://localhost:8080/" + }, + "execution": { + "iopub.execute_input": "2024-03-22T16:54:24.524116Z", + "iopub.status.busy": "2024-03-22T16:54:24.523692Z", + "iopub.status.idle": "2024-03-22T16:54:24.530799Z", + "shell.execute_reply": "2024-03-22T16:54:24.529968Z" + }, + "executionInfo": { + "elapsed": 13, + "status": "ok", + "timestamp": 1696841045411, + "user": { + "displayName": "Rizqi Nur", + "userId": "09644007964068789560" + }, + "user_tz": -420 + }, + "id": "OxUH_GBEv2qK", + "outputId": "76464c90-3baf-4bdc-a955-6f4fddc16b9c", + "papermill": { + "duration": 0.024897, + "end_time": "2024-03-22T16:54:24.532947", + "exception": false, + "start_time": "2024-03-22T16:54:24.508050", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'tvae': 95,\n", + " 'realtabformer': (69, 281, Embedding(281, 768), True),\n", + " 'lct_gan': 75,\n", + " 'tab_ddpm_concat': 12}" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "preprocessor.adapter_sizes" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "3cb9ed90", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T16:54:24.559222Z", + "iopub.status.busy": "2024-03-22T16:54:24.558898Z", + "iopub.status.idle": "2024-03-22T16:54:24.563980Z", + "shell.execute_reply": "2024-03-22T16:54:24.563114Z" + }, + "papermill": { + "duration": 0.020558, + "end_time": "2024-03-22T16:54:24.565917", + "exception": false, + "start_time": "2024-03-22T16:54:24.545359", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "from ml_utility_loss.loss_learning.estimator.pipeline import load_dataset_3_factory\n", + "\n", + "datasetsn = load_dataset_3_factory(\n", + " dataset_dir=os.path.join(path_prefix, \"ml-utility-loss/\"),\n", + " dataset_name=dataset_name,\n", + " preprocessor=preprocessor,\n", + " cache_dir=path_prefix,\n", + ")\n" + ] + }, + { + "cell_type": 
"code", + "execution_count": 17, + "id": "ad1eb833", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T16:54:24.591748Z", + "iopub.status.busy": "2024-03-22T16:54:24.591466Z", + "iopub.status.idle": "2024-03-22T17:18:56.244286Z", + "shell.execute_reply": "2024-03-22T17:18:56.243405Z" + }, + "papermill": { + "duration": 1471.680855, + "end_time": "2024-03-22T17:18:56.259077", + "exception": false, + "start_time": "2024-03-22T16:54:24.578222", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Caching in ../../../../treatment/_cache_aug_test/lct_gan/all inf False\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "../../../../ml-utility-loss/aug_test/treatment [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n", + "Caching in ../../../../treatment/_cache_bs_test/lct_gan/all inf False\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "../../../../ml-utility-loss/bs_test/treatment [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n", + "Caching in ../../../../treatment/_cache_synth_test/lct_gan/all inf False\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "../../../../ml-utility-loss/synthetics/treatment [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n", + "1050\n" + ] + } + ], + "source": [ + "from ml_utility_loss.loss_learning.estimator.pipeline import load_dataset_4\n", + "\n", + "test_set = load_dataset_4(\n", + " dataset_dir=os.path.join(path_prefix, \"ml-utility-loss/\"),\n", + " dataset_name=dataset_name,\n", + " preprocessor=preprocessor,\n", + " model=single_model,\n", + " cache_dir=path_prefix,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "14ff8b40", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T17:18:56.287603Z", + "iopub.status.busy": "2024-03-22T17:18:56.287251Z", + "iopub.status.idle": "2024-03-22T17:18:56.608589Z", + "shell.execute_reply": "2024-03-22T17:18:56.607712Z" + }, + "executionInfo": { + "elapsed": 588, + "status": "ok", + "timestamp": 1696841049215, + "user": { + "displayName": "Rizqi Nur", + "userId": "09644007964068789560" + }, + "user_tz": -420 + }, + "id": "NgahtU1q9uLO", + "papermill": { + "duration": 0.338263, + "end_time": "2024-03-22T17:18:56.610783", + "exception": false, + "start_time": "2024-03-22T17:18:56.272520", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "{'Body': 'twin_encoder',\n", + " 'loss_balancer_meta': True,\n", + " 'loss_balancer_log': False,\n", + " 'loss_balancer_lbtw': False,\n", + " 'pma_skip_small': False,\n", + " 'isab_skip_small': False,\n", + " 'layer_norm': False,\n", + " 'pma_layer_norm': False,\n", + " 'attn_residual': True,\n", + " 'tf_n_layers_dec': False,\n", + " 'tf_isab_rank': 0,\n", + " 'tf_layer_norm': False,\n", + " 'tf_pma_start': -1,\n", + " 'head_n_seeds': 0,\n", + " 'tf_pma_low': 
16,\n", + " 'dropout': 0,\n", + " 'combine_mode': 'diff_left',\n", + " 'tf_isab_mode': 'separate',\n", + " 'grad_loss_fn': torch.Tensor>,\n", + " 'single_model': True,\n", + " 'bias': True,\n", + " 'bias_final': True,\n", + " 'pma_ffn_mode': 'shared',\n", + " 'patience': 10,\n", + " 'inds_init_mode': 'torch',\n", + " 'grad_clip': 0.8,\n", + " 'gradient_penalty_mode': {'gradient_penalty': True,\n", + " 'forward_once': False,\n", + " 'calc_grad_m': False,\n", + " 'avg_non_role_model_m': False,\n", + " 'inverse_avg_non_role_model_m': False},\n", + " 'synth_data': 2,\n", + " 'bias_lr_mul': 1.0,\n", + " 'bias_weight_decay': 0.1,\n", + " 'loss_balancer_beta': 0.73,\n", + " 'loss_balancer_r': 0.94,\n", + " 'dataset_size': 2048,\n", + " 'batch_size': 4,\n", + " 'epochs': 100,\n", + " 'lr_mul': 0.04,\n", + " 'n_warmup_steps': 220,\n", + " 'Optim': torch_optimizer.diffgrad.DiffGrad,\n", + " 'fixed_role_model': 'lct_gan',\n", + " 'd_model': 512,\n", + " 'attn_activation': ml_utility_loss.activations.LeakyHardsigmoid,\n", + " 'tf_d_inner': 512,\n", + " 'tf_n_layers_enc': 4,\n", + " 'tf_n_head': 64,\n", + " 'tf_activation': ml_utility_loss.activations.LeakyHardtanh,\n", + " 'tf_activation_final': ml_utility_loss.activations.LeakyHardtanh,\n", + " 'ada_d_hid': 1024,\n", + " 'ada_n_layers': 7,\n", + " 'ada_activation': torch.nn.modules.activation.SELU,\n", + " 'ada_activation_final': ml_utility_loss.activations.LeakyHardsigmoid,\n", + " 'head_d_hid': 128,\n", + " 'head_n_layers': 8,\n", + " 'head_n_head': 64,\n", + " 'head_activation': ml_utility_loss.activations.LeakyHardsigmoid,\n", + " 'head_activation_final': ml_utility_loss.activations.LeakyHardsigmoid,\n", + " 'models': ['lct_gan'],\n", + " 'max_seconds': 3600,\n", + " 'tf_lora': False,\n", + " 'tf_num_inds': 64,\n", + " 'ada_n_seeds': 0,\n", + " 'gradient_penalty_kwargs': {'mag_loss': True,\n", + " 'mse_mag': True,\n", + " 'mag_corr': False,\n", + " 'seq_mag': False,\n", + " 'cos_loss': False,\n", + " 'mag_corr_kwargs': 
{'only_sign': False},\n", + " 'cos_loss_kwargs': {'only_sign': True, 'cos_matrix': False},\n", + " 'mse_mag_kwargs': {'target': 0.2, 'multiply': True, 'forgive_over': True}}}" + ] + }, + "execution_count": 18, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import ml_utility_loss.loss_learning.estimator.params2 as PARAMS\n", + "from ml_utility_loss.tuning import map_parameters\n", + "from ml_utility_loss.loss_learning.estimator.params.default import update_param_space, update_param_space_2\n", + "import wandb\n", + "\n", + "#\"\"\"\n", + "param_space = {\n", + " **getattr(PARAMS, dataset_name).PARAM_SPACE,\n", + "}\n", + "params = {\n", + " **getattr(PARAMS, dataset_name).BESTS[param_index],\n", + "}\n", + "if gp:\n", + " params[\"gradient_penalty_mode\"] = \"ALL\"\n", + " params[\"mse_mag\"] = True\n", + " if gp_multiply:\n", + " params[\"mse_mag_multiply\"] = True\n", + " #params[\"mse_mag_target\"] = 1.0\n", + " else:\n", + " params[\"mse_mag_multiply\"] = False\n", + " #params[\"mse_mag_target\"] = 0.1\n", + "else:\n", + " params[\"gradient_penalty_mode\"] = \"NONE\"\n", + " params[\"mse_mag\"] = False\n", + "params[\"single_model\"] = False\n", + "if models:\n", + " params[\"models\"] = models\n", + "if single_model:\n", + " params[\"fixed_role_model\"] = single_model\n", + " params[\"single_model\"] = True\n", + " params[\"models\"] = [single_model]\n", + "if params[\"fixed_role_model\"] == \"realtabformer\" and dataset_name == \"treatment\":\n", + " params[\"batch_size\"] = 2\n", + "params[\"max_seconds\"] = 3600\n", + "params[\"patience\"] = 10\n", + "params[\"epochs\"] = 100\n", + "if debug:\n", + " params[\"epochs\"] = 2\n", + "with open(\"params.json\", \"w\") as f:\n", + " json.dump(params, f)\n", + "params = map_parameters(params, param_space=param_space)\n", + "params" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "a48bd9e9", + "metadata": { + "execution": { + "iopub.execute_input": 
"2024-03-22T17:18:56.639191Z", + "iopub.status.busy": "2024-03-22T17:18:56.638880Z", + "iopub.status.idle": "2024-03-22T17:50:01.925295Z", + "shell.execute_reply": "2024-03-22T17:50:01.924278Z" + }, + "papermill": { + "duration": 1865.316192, + "end_time": "2024-03-22T17:50:01.940639", + "exception": false, + "start_time": "2024-03-22T17:18:56.624447", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Caching in ../../../../treatment/_cache_aug_train/lct_gan/all inf False\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "split df ratio is 0\n", + "../../../../ml-utility-loss/aug_train/treatment [400, 0]\n", + "Caching in ../../../../treatment/_cache_aug_val/lct_gan/all inf False\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "split df ratio is 1\n", + "../../../../ml-utility-loss/aug_val/treatment [0, 200]\n", + "Caching in ../../../../treatment/_cache_bs_train/lct_gan/all inf False\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "split df ratio is 0\n", + "../../../../ml-utility-loss/bs_train/treatment [100, 0]\n", + "Caching in ../../../../treatment/_cache_bs_val/lct_gan/all inf False\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "split df ratio is 1\n", + "../../../../ml-utility-loss/bs_val/treatment [0, 50]\n", + "Caching in ../../../../treatment/_cache_synth/lct_gan/all inf False\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Splitting without random!\n", + "Split with reverse index!\n", + "../../../../ml-utility-loss/synthetics/treatment [400, 200]\n", + "[900, 450]\n", + "[900, 450]\n" + ] + } + ], + "source": [ + "train_set, val_set = datasetsn(model=params[\"fixed_role_model\"], synth_data=params[\"synth_data\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "2fcb1418", + "metadata": { + "colab": { + 
"base_uri": "https://localhost:8080/", + "height": 1000 + }, + "execution": { + "iopub.execute_input": "2024-03-22T17:50:01.971379Z", + "iopub.status.busy": "2024-03-22T17:50:01.970531Z", + "iopub.status.idle": "2024-03-22T17:50:02.499249Z", + "shell.execute_reply": "2024-03-22T17:50:02.498338Z" + }, + "executionInfo": { + "elapsed": 396850, + "status": "error", + "timestamp": 1696841446059, + "user": { + "displayName": "Rizqi Nur", + "userId": "09644007964068789560" + }, + "user_tz": -420 + }, + "id": "_bt1MQc5kpSk", + "outputId": "01c1d3e5-ac64-461d-835a-b76f4a66e6d6", + "papermill": { + "duration": 0.546432, + "end_time": "2024-03-22T17:50:02.501385", + "exception": false, + "start_time": "2024-03-22T17:50:01.954953", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Creating model of type \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[*] Embedding False True\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['lct_gan'] 1\n" + ] + } + ], + "source": [ + "from ml_utility_loss.loss_learning.estimator.model.pipeline import remove_non_model_params\n", + "from ml_utility_loss.loss_learning.estimator.pipeline import create_model\n", + "from ml_utility_loss.util import filter_dict, clear_memory\n", + "\n", + "clear_memory()\n", + "\n", + "params2 = remove_non_model_params(params)\n", + "adapters = filter_dict(preprocessor.adapter_sizes, params[\"models\"])\n", + "\n", + "model = create_model(\n", + " adapters=adapters,\n", + " #Body=\"twin_encoder\",\n", + " **params2,\n", + ")\n", + "#cf.apply_weight_standardization(model, n_last_layers_ignore=0)\n", + "print(model.models, len(model.adapters))" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "938f94fc", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T17:50:02.532791Z", + "iopub.status.busy": "2024-03-22T17:50:02.532446Z", + 
"iopub.status.idle": "2024-03-22T17:50:02.536685Z", + "shell.execute_reply": "2024-03-22T17:50:02.535782Z" + }, + "papermill": { + "duration": 0.023513, + "end_time": "2024-03-22T17:50:02.539641", + "exception": false, + "start_time": "2024-03-22T17:50:02.516128", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "study_name=f\"{model_name}_{dataset_name}\"" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "12fb613e", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T17:50:02.568171Z", + "iopub.status.busy": "2024-03-22T17:50:02.567868Z", + "iopub.status.idle": "2024-03-22T17:50:02.574874Z", + "shell.execute_reply": "2024-03-22T17:50:02.573995Z" + }, + "papermill": { + "duration": 0.023652, + "end_time": "2024-03-22T17:50:02.576882", + "exception": false, + "start_time": "2024-03-22T17:50:02.553230", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "18680833" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "def count_parameters(model):\n", + " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", + "\n", + "count_parameters(model)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "bd386e57", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T17:50:02.605328Z", + "iopub.status.busy": "2024-03-22T17:50:02.605063Z", + "iopub.status.idle": "2024-03-22T17:50:02.699743Z", + "shell.execute_reply": "2024-03-22T17:50:02.698806Z" + }, + "papermill": { + "duration": 0.111362, + "end_time": "2024-03-22T17:50:02.702059", + "exception": false, + "start_time": "2024-03-22T17:50:02.590697", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "========================================================================================================================\n", + "Layer (type:depth-idx) Output 
Shape Param #\n", + "========================================================================================================================\n", + "MLUtilitySingle [2, 2648, 75] --\n", + "├─Adapter: 1-1 [2, 2648, 75] --\n", + "│ └─Sequential: 2-1 [2, 2648, 512] --\n", + "│ │ └─FeedForward: 3-1 [2, 2648, 1024] --\n", + "│ │ │ └─Linear: 4-1 [2, 2648, 1024] 77,824\n", + "│ │ │ └─SELU: 4-2 [2, 2648, 1024] --\n", + "│ │ └─FeedForward: 3-2 [2, 2648, 1024] --\n", + "│ │ │ └─Linear: 4-3 [2, 2648, 1024] 1,049,600\n", + "│ │ │ └─SELU: 4-4 [2, 2648, 1024] --\n", + "│ │ └─FeedForward: 3-3 [2, 2648, 1024] --\n", + "│ │ │ └─Linear: 4-5 [2, 2648, 1024] 1,049,600\n", + "│ │ │ └─SELU: 4-6 [2, 2648, 1024] --\n", + "│ │ └─FeedForward: 3-4 [2, 2648, 1024] --\n", + "│ │ │ └─Linear: 4-7 [2, 2648, 1024] 1,049,600\n", + "│ │ │ └─SELU: 4-8 [2, 2648, 1024] --\n", + "│ │ └─FeedForward: 3-5 [2, 2648, 1024] --\n", + "│ │ │ └─Linear: 4-9 [2, 2648, 1024] 1,049,600\n", + "│ │ │ └─SELU: 4-10 [2, 2648, 1024] --\n", + "│ │ └─FeedForward: 3-6 [2, 2648, 1024] --\n", + "│ │ │ └─Linear: 4-11 [2, 2648, 1024] 1,049,600\n", + "│ │ │ └─SELU: 4-12 [2, 2648, 1024] --\n", + "│ │ └─FeedForward: 3-7 [2, 2648, 512] --\n", + "│ │ │ └─Linear: 4-13 [2, 2648, 512] 524,800\n", + "│ │ │ └─LeakyHardsigmoid: 4-14 [2, 2648, 512] --\n", + "├─Adapter: 1-2 [2, 661, 75] (recursive)\n", + "│ └─Sequential: 2-2 [2, 661, 512] (recursive)\n", + "│ │ └─FeedForward: 3-8 [2, 661, 1024] (recursive)\n", + "│ │ │ └─Linear: 4-15 [2, 661, 1024] (recursive)\n", + "│ │ │ └─SELU: 4-16 [2, 661, 1024] --\n", + "│ │ └─FeedForward: 3-9 [2, 661, 1024] (recursive)\n", + "│ │ │ └─Linear: 4-17 [2, 661, 1024] (recursive)\n", + "│ │ │ └─SELU: 4-18 [2, 661, 1024] --\n", + "│ │ └─FeedForward: 3-10 [2, 661, 1024] (recursive)\n", + "│ │ │ └─Linear: 4-19 [2, 661, 1024] (recursive)\n", + "│ │ │ └─SELU: 4-20 [2, 661, 1024] --\n", + "│ │ └─FeedForward: 3-11 [2, 661, 1024] (recursive)\n", + "│ │ │ └─Linear: 4-21 [2, 661, 1024] (recursive)\n", + "│ │ │ └─SELU: 
4-22 [2, 661, 1024] --\n", + "│ │ └─FeedForward: 3-12 [2, 661, 1024] (recursive)\n", + "│ │ │ └─Linear: 4-23 [2, 661, 1024] (recursive)\n", + "│ │ │ └─SELU: 4-24 [2, 661, 1024] --\n", + "│ │ └─FeedForward: 3-13 [2, 661, 1024] (recursive)\n", + "│ │ │ └─Linear: 4-25 [2, 661, 1024] (recursive)\n", + "│ │ │ └─SELU: 4-26 [2, 661, 1024] --\n", + "│ │ └─FeedForward: 3-14 [2, 661, 512] (recursive)\n", + "│ │ │ └─Linear: 4-27 [2, 661, 512] (recursive)\n", + "│ │ │ └─LeakyHardsigmoid: 4-28 [2, 661, 512] --\n", + "├─TwinEncoder: 1-3 [2, 8192] --\n", + "│ └─Encoder: 2-3 [2, 16, 512] --\n", + "│ │ └─ModuleList: 3-16 -- (recursive)\n", + "│ │ │ └─EncoderLayer: 4-29 [2, 2648, 512] --\n", + "│ │ │ │ └─SimpleInducedSetAttention: 5-1 [2, 2648, 512] --\n", + "│ │ │ │ │ └─TensorInductionPoint: 6-1 [2, 64, 512] 32,768\n", + "│ │ │ │ │ └─MultiHeadAttention: 6-2 [2, 64, 512] --\n", + "│ │ │ │ │ │ └─Linear: 7-1 [2, 64, 512] 262,144\n", + "│ │ │ │ │ │ └─Linear: 7-2 [2, 2648, 512] 262,144\n", + "│ │ │ │ │ │ └─Linear: 7-3 [2, 2648, 512] 262,144\n", + "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-4 [2, 64, 64, 8] --\n", + "│ │ │ │ │ │ │ └─Softmax: 8-1 [2, 64, 64, 2648] --\n", + "│ │ │ │ │ │ └─Linear: 7-5 [2, 64, 512] 262,656\n", + "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-6 [2, 64, 512] --\n", + "│ │ │ │ │ └─MultiHeadAttention: 6-3 [2, 2648, 512] --\n", + "│ │ │ │ │ │ └─Linear: 7-7 [2, 2648, 512] 262,144\n", + "│ │ │ │ │ │ └─Linear: 7-8 [2, 64, 512] 262,144\n", + "│ │ │ │ │ │ └─Linear: 7-9 [2, 64, 512] 262,144\n", + "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-10 [2, 64, 2648, 8] --\n", + "│ │ │ │ │ │ │ └─Softmax: 8-2 [2, 64, 2648, 64] --\n", + "│ │ │ │ │ │ └─Linear: 7-11 [2, 2648, 512] 262,656\n", + "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-12 [2, 2648, 512] --\n", + "│ │ │ │ └─DoubleFeedForward: 5-2 [2, 2648, 512] --\n", + "│ │ │ │ │ └─Linear: 6-4 [2, 2648, 512] 262,656\n", + "│ │ │ │ │ └─LeakyHardtanh: 6-5 [2, 2648, 512] --\n", + "│ │ │ │ │ └─Linear: 6-6 [2, 2648, 512] 262,656\n", + "│ │ │ └─EncoderLayer: 
4-30 [2, 2648, 512] --\n", + "│ │ │ │ └─SimpleInducedSetAttention: 5-3 [2, 2648, 512] --\n", + "│ │ │ │ │ └─TensorInductionPoint: 6-7 [2, 64, 512] 32,768\n", + "│ │ │ │ │ └─MultiHeadAttention: 6-8 [2, 64, 512] --\n", + "│ │ │ │ │ │ └─Linear: 7-13 [2, 64, 512] 262,144\n", + "│ │ │ │ │ │ └─Linear: 7-14 [2, 2648, 512] 262,144\n", + "│ │ │ │ │ │ └─Linear: 7-15 [2, 2648, 512] 262,144\n", + "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-16 [2, 64, 64, 8] --\n", + "│ │ │ │ │ │ │ └─Softmax: 8-3 [2, 64, 64, 2648] --\n", + "│ │ │ │ │ │ └─Linear: 7-17 [2, 64, 512] 262,656\n", + "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-18 [2, 64, 512] --\n", + "│ │ │ │ │ └─MultiHeadAttention: 6-9 [2, 2648, 512] --\n", + "│ │ │ │ │ │ └─Linear: 7-19 [2, 2648, 512] 262,144\n", + "│ │ │ │ │ │ └─Linear: 7-20 [2, 64, 512] 262,144\n", + "│ │ │ │ │ │ └─Linear: 7-21 [2, 64, 512] 262,144\n", + "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-22 [2, 64, 2648, 8] --\n", + "│ │ │ │ │ │ │ └─Softmax: 8-4 [2, 64, 2648, 64] --\n", + "│ │ │ │ │ │ └─Linear: 7-23 [2, 2648, 512] 262,656\n", + "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-24 [2, 2648, 512] --\n", + "│ │ │ │ └─DoubleFeedForward: 5-4 [2, 2648, 512] --\n", + "│ │ │ │ │ └─Linear: 6-10 [2, 2648, 512] 262,656\n", + "│ │ │ │ │ └─LeakyHardtanh: 6-11 [2, 2648, 512] --\n", + "│ │ │ │ │ └─Linear: 6-12 [2, 2648, 512] 262,656\n", + "│ │ │ └─EncoderLayer: 4-31 [2, 2648, 512] --\n", + "│ │ │ │ └─SimpleInducedSetAttention: 5-5 [2, 2648, 512] --\n", + "│ │ │ │ │ └─TensorInductionPoint: 6-13 [2, 64, 512] 32,768\n", + "│ │ │ │ │ └─MultiHeadAttention: 6-14 [2, 64, 512] --\n", + "│ │ │ │ │ │ └─Linear: 7-25 [2, 64, 512] 262,144\n", + "│ │ │ │ │ │ └─Linear: 7-26 [2, 2648, 512] 262,144\n", + "│ │ │ │ │ │ └─Linear: 7-27 [2, 2648, 512] 262,144\n", + "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-28 [2, 64, 64, 8] --\n", + "│ │ │ │ │ │ │ └─Softmax: 8-5 [2, 64, 64, 2648] --\n", + "│ │ │ │ │ │ └─Linear: 7-29 [2, 64, 512] 262,656\n", + "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-30 [2, 64, 512] --\n", + "│ │ │ │ │ 
└─MultiHeadAttention: 6-15 [2, 2648, 512] --\n", + "│ │ │ │ │ │ └─Linear: 7-31 [2, 2648, 512] 262,144\n", + "│ │ │ │ │ │ └─Linear: 7-32 [2, 64, 512] 262,144\n", + "│ │ │ │ │ │ └─Linear: 7-33 [2, 64, 512] 262,144\n", + "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-34 [2, 64, 2648, 8] --\n", + "│ │ │ │ │ │ │ └─Softmax: 8-6 [2, 64, 2648, 64] --\n", + "│ │ │ │ │ │ └─Linear: 7-35 [2, 2648, 512] 262,656\n", + "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-36 [2, 2648, 512] --\n", + "│ │ │ │ └─DoubleFeedForward: 5-6 [2, 2648, 512] --\n", + "│ │ │ │ │ └─Linear: 6-16 [2, 2648, 512] 262,656\n", + "│ │ │ │ │ └─LeakyHardtanh: 6-17 [2, 2648, 512] --\n", + "│ │ │ │ │ └─Linear: 6-18 [2, 2648, 512] 262,656\n", + "│ │ │ └─EncoderLayer: 4-32 [2, 16, 512] --\n", + "│ │ │ │ └─SimpleInducedSetAttention: 5-7 [2, 2648, 512] --\n", + "│ │ │ │ │ └─TensorInductionPoint: 6-19 [2, 64, 512] 32,768\n", + "│ │ │ │ │ └─MultiHeadAttention: 6-20 [2, 64, 512] --\n", + "│ │ │ │ │ │ └─Linear: 7-37 [2, 64, 512] 262,144\n", + "│ │ │ │ │ │ └─Linear: 7-38 [2, 2648, 512] 262,144\n", + "│ │ │ │ │ │ └─Linear: 7-39 [2, 2648, 512] 262,144\n", + "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-40 [2, 64, 64, 8] --\n", + "│ │ │ │ │ │ │ └─Softmax: 8-7 [2, 64, 64, 2648] --\n", + "│ │ │ │ │ │ └─Linear: 7-41 [2, 64, 512] 262,656\n", + "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-42 [2, 64, 512] --\n", + "│ │ │ │ │ └─MultiHeadAttention: 6-21 [2, 2648, 512] --\n", + "│ │ │ │ │ │ └─Linear: 7-43 [2, 2648, 512] 262,144\n", + "│ │ │ │ │ │ └─Linear: 7-44 [2, 64, 512] 262,144\n", + "│ │ │ │ │ │ └─Linear: 7-45 [2, 64, 512] 262,144\n", + "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-46 [2, 64, 2648, 8] --\n", + "│ │ │ │ │ │ │ └─Softmax: 8-8 [2, 64, 2648, 64] --\n", + "│ │ │ │ │ │ └─Linear: 7-47 [2, 2648, 512] 262,656\n", + "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-48 [2, 2648, 512] --\n", + "│ │ │ │ └─DoubleFeedForward: 5-8 [2, 2648, 512] --\n", + "│ │ │ │ │ └─Linear: 6-22 [2, 2648, 512] 262,656\n", + "│ │ │ │ │ └─LeakyHardtanh: 6-23 [2, 2648, 512] --\n", + "│ │ 
│ │ │ └─Linear: 6-24 [2, 2648, 512] 262,656\n", + "│ │ │ │ └─PoolingByMultiheadAttention: 5-9 [2, 16, 512] --\n", + "│ │ │ │ │ └─TensorInductionPoint: 6-25 [2, 16, 512] 8,192\n", + "│ │ │ │ │ └─SimpleMultiHeadAttention: 6-26 [2, 16, 512] --\n", + "│ │ │ │ │ │ └─Linear: 7-49 [2, 16, 512] 262,144\n", + "│ │ │ │ │ │ └─Linear: 7-50 [2, 2648, 512] 262,144\n", + "│ │ │ │ │ │ └─Linear: 7-51 [2, 2648, 512] 262,144\n", + "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-52 [2, 64, 16, 8] --\n", + "│ │ │ │ │ │ │ └─Softmax: 8-9 [2, 64, 16, 2648] --\n", + "│ │ │ │ │ │ └─Linear: 7-53 [2, 16, 512] 262,656\n", + "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-54 [2, 16, 512] --\n", + "│ │ │ │ └─DoubleFeedForward: 5-10 [2, 16, 512] (recursive)\n", + "│ │ │ │ │ └─Linear: 6-27 [2, 16, 512] (recursive)\n", + "│ │ │ │ │ └─LeakyHardtanh: 6-28 [2, 16, 512] --\n", + "│ │ │ │ │ └─Linear: 6-29 [2, 16, 512] (recursive)\n", + "│ └─Encoder: 2-4 [2, 16, 512] (recursive)\n", + "│ │ └─ModuleList: 3-16 -- (recursive)\n", + "│ │ │ └─EncoderLayer: 4-33 [2, 661, 512] (recursive)\n", + "│ │ │ │ └─SimpleInducedSetAttention: 5-11 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ └─TensorInductionPoint: 6-30 [2, 64, 512] (recursive)\n", + "│ │ │ │ │ └─MultiHeadAttention: 6-31 [2, 64, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-55 [2, 64, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-56 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-57 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-58 [2, 64, 64, 8] --\n", + "│ │ │ │ │ │ │ └─Softmax: 8-10 [2, 64, 64, 661] --\n", + "│ │ │ │ │ │ └─Linear: 7-59 [2, 64, 512] (recursive)\n", + "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-60 [2, 64, 512] --\n", + "│ │ │ │ │ └─MultiHeadAttention: 6-32 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-61 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-62 [2, 64, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-63 [2, 64, 512] (recursive)\n", + "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-64 [2, 64, 661, 8] 
--\n", + "│ │ │ │ │ │ │ └─Softmax: 8-11 [2, 64, 661, 64] --\n", + "│ │ │ │ │ │ └─Linear: 7-65 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-66 [2, 661, 512] --\n", + "│ │ │ │ └─DoubleFeedForward: 5-12 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ └─Linear: 6-33 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ └─LeakyHardtanh: 6-34 [2, 661, 512] --\n", + "│ │ │ │ │ └─Linear: 6-35 [2, 661, 512] (recursive)\n", + "│ │ │ └─EncoderLayer: 4-34 [2, 661, 512] (recursive)\n", + "│ │ │ │ └─SimpleInducedSetAttention: 5-13 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ └─TensorInductionPoint: 6-36 [2, 64, 512] (recursive)\n", + "│ │ │ │ │ └─MultiHeadAttention: 6-37 [2, 64, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-67 [2, 64, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-68 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-69 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-70 [2, 64, 64, 8] --\n", + "│ │ │ │ │ │ │ └─Softmax: 8-12 [2, 64, 64, 661] --\n", + "│ │ │ │ │ │ └─Linear: 7-71 [2, 64, 512] (recursive)\n", + "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-72 [2, 64, 512] --\n", + "│ │ │ │ │ └─MultiHeadAttention: 6-38 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-73 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-74 [2, 64, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-75 [2, 64, 512] (recursive)\n", + "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-76 [2, 64, 661, 8] --\n", + "│ │ │ │ │ │ │ └─Softmax: 8-13 [2, 64, 661, 64] --\n", + "│ │ │ │ │ │ └─Linear: 7-77 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-78 [2, 661, 512] --\n", + "│ │ │ │ └─DoubleFeedForward: 5-14 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ └─Linear: 6-39 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ └─LeakyHardtanh: 6-40 [2, 661, 512] --\n", + "│ │ │ │ │ └─Linear: 6-41 [2, 661, 512] (recursive)\n", + "│ │ │ └─EncoderLayer: 4-35 [2, 661, 512] (recursive)\n", + "│ │ │ │ └─SimpleInducedSetAttention: 5-15 [2, 661, 512] (recursive)\n", +
"│ │ │ │ │ └─TensorInductionPoint: 6-42 [2, 64, 512] (recursive)\n", + "│ │ │ │ │ └─MultiHeadAttention: 6-43 [2, 64, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-79 [2, 64, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-80 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-81 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-82 [2, 64, 64, 8] --\n", + "│ │ │ │ │ │ │ └─Softmax: 8-14 [2, 64, 64, 661] --\n", + "│ │ │ │ │ │ └─Linear: 7-83 [2, 64, 512] (recursive)\n", + "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-84 [2, 64, 512] --\n", + "│ │ │ │ │ └─MultiHeadAttention: 6-44 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-85 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-86 [2, 64, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-87 [2, 64, 512] (recursive)\n", + "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-88 [2, 64, 661, 8] --\n", + "│ │ │ │ │ │ │ └─Softmax: 8-15 [2, 64, 661, 64] --\n", + "│ │ │ │ │ │ └─Linear: 7-89 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-90 [2, 661, 512] --\n", + "│ │ │ │ └─DoubleFeedForward: 5-16 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ └─Linear: 6-45 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ └─LeakyHardtanh: 6-46 [2, 661, 512] --\n", + "│ │ │ │ │ └─Linear: 6-47 [2, 661, 512] (recursive)\n", + "│ │ │ └─EncoderLayer: 4-36 [2, 16, 512] (recursive)\n", + "│ │ │ │ └─SimpleInducedSetAttention: 5-17 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ └─TensorInductionPoint: 6-48 [2, 64, 512] (recursive)\n", + "│ │ │ │ │ └─MultiHeadAttention: 6-49 [2, 64, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-91 [2, 64, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-92 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-93 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-94 [2, 64, 64, 8] --\n", + "│ │ │ │ │ │ │ └─Softmax: 8-16 [2, 64, 64, 661] --\n", + "│ │ │ │ │ │ └─Linear: 7-95 [2, 64, 512] (recursive)\n", + "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-96 [2, 64, 512]
--\n", + "│ │ │ │ │ └─MultiHeadAttention: 6-50 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-97 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-98 [2, 64, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-99 [2, 64, 512] (recursive)\n", + "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-100 [2, 64, 661, 8] --\n", + "│ │ │ │ │ │ │ └─Softmax: 8-17 [2, 64, 661, 64] --\n", + "│ │ │ │ │ │ └─Linear: 7-101 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-102 [2, 661, 512] --\n", + "│ │ │ │ └─DoubleFeedForward: 5-18 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ └─Linear: 6-51 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ └─LeakyHardtanh: 6-52 [2, 661, 512] --\n", + "│ │ │ │ │ └─Linear: 6-53 [2, 661, 512] (recursive)\n", + "│ │ │ │ └─PoolingByMultiheadAttention: 5-19 [2, 16, 512] (recursive)\n", + "│ │ │ │ │ └─TensorInductionPoint: 6-54 [2, 16, 512] (recursive)\n", + "│ │ │ │ │ └─SimpleMultiHeadAttention: 6-55 [2, 16, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-103 [2, 16, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-104 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ │ └─Linear: 7-105 [2, 661, 512] (recursive)\n", + "│ │ │ │ │ │ └─ScaledDotProductAttention: 7-106 [2, 64, 16, 8] --\n", + "│ │ │ │ │ │ │ └─Softmax: 8-18 [2, 64, 16, 661] --\n", + "│ │ │ │ │ │ └─Linear: 7-107 [2, 16, 512] (recursive)\n", + "│ │ │ │ │ │ └─LeakyHardsigmoid: 7-108 [2, 16, 512] --\n", + "│ │ │ │ └─DoubleFeedForward: 5-20 [2, 16, 512] (recursive)\n", + "│ │ │ │ │ └─Linear: 6-56 [2, 16, 512] (recursive)\n", + "│ │ │ │ │ └─LeakyHardtanh: 6-57 [2, 16, 512] --\n", + "│ │ │ │ │ └─Linear: 6-58 [2, 16, 512] (recursive)\n", + "├─Head: 1-4 [2] --\n", + "│ └─Sequential: 2-5 [2, 1] --\n", + "│ │ └─FeedForward: 3-17 [2, 128] --\n", + "│ │ │ └─Linear: 4-37 [2, 128] 1,048,704\n", + "│ │ │ └─LeakyHardsigmoid: 4-38 [2, 128] --\n", + "│ │ └─FeedForward: 3-18 [2, 128] --\n", + "│ │ │ └─Linear: 4-39 [2, 128] 16,512\n", + "│ │ │ └─LeakyHardsigmoid: 4-40 [2, 128] --\n", + "│ │ └─FeedForward: 
3-19 [2, 128] --\n", + "│ │ │ └─Linear: 4-41 [2, 128] 16,512\n", + "│ │ │ └─LeakyHardsigmoid: 4-42 [2, 128] --\n", + "│ │ └─FeedForward: 3-20 [2, 128] --\n", + "│ │ │ └─Linear: 4-43 [2, 128] 16,512\n", + "│ │ │ └─LeakyHardsigmoid: 4-44 [2, 128] --\n", + "│ │ └─FeedForward: 3-21 [2, 128] --\n", + "│ │ │ └─Linear: 4-45 [2, 128] 16,512\n", + "│ │ │ └─LeakyHardsigmoid: 4-46 [2, 128] --\n", + "│ │ └─FeedForward: 3-22 [2, 128] --\n", + "│ │ │ └─Linear: 4-47 [2, 128] 16,512\n", + "│ │ │ └─LeakyHardsigmoid: 4-48 [2, 128] --\n", + "│ │ └─FeedForward: 3-23 [2, 128] --\n", + "│ │ │ └─Linear: 4-49 [2, 128] 16,512\n", + "│ │ │ └─LeakyHardsigmoid: 4-50 [2, 128] --\n", + "│ │ └─FeedForward: 3-24 [2, 1] --\n", + "│ │ │ └─Linear: 4-51 [2, 1] 129\n", + "│ │ │ └─LeakyHardsigmoid: 4-52 [2, 1] --\n", + "========================================================================================================================\n", + "Total params: 18,680,833\n", + "Trainable params: 18,680,833\n", + "Non-trainable params: 0\n", + "Total mult-adds (M): 73.97\n", + "========================================================================================================================\n", + "Input size (MB): 1.99\n", + "Forward/backward pass size (MB): 1079.48\n", + "Params size (MB): 74.72\n", + "Estimated Total Size (MB): 1156.19\n", + "========================================================================================================================" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from torchinfo import summary\n", + "\n", + "role_model = params[\"fixed_role_model\"]\n", + "s = train_set[0][role_model]\n", + "summary(model[role_model], input_size=((2, *s[0].shape), (2, *s[1].shape)), depth=9) # 8 max" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "0f42c4d1", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T17:50:02.735226Z", + "iopub.status.busy": 
"2024-03-22T17:50:02.734892Z", + "iopub.status.idle": "2024-03-22T19:00:38.391656Z", + "shell.execute_reply": "2024-03-22T19:00:38.390642Z" + }, + "papermill": { + "duration": 4235.677039, + "end_time": "2024-03-22T19:00:38.394746", + "exception": false, + "start_time": "2024-03-22T17:50:02.717707", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3 datasets [900, 450, 1050]\n", + "Creating model of type \n", + "[*] Embedding False True\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "g_loss_mul 0.1\n", + "Epoch 0\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train loss {'avg_role_model_loss': 0.08288680652415173, 'avg_role_model_std_loss': 14.884825719568891, 'avg_role_model_mean_pred_loss': 0.014639717982716328, 'avg_role_model_g_mag_loss': 0.007401900000145865, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.406318084973221, 'n_size': 900, 'n_batch': 225, 'duration': 398.05458784103394, 'duration_batch': 1.7691315015157063, 'duration_size': 0.4422828753789266, 'avg_pred_std': 0.04329945017169747}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Val loss {'avg_role_model_loss': 0.01646603072894332, 'avg_role_model_std_loss': 1.1389748531219954, 'avg_role_model_mean_pred_loss': 0.001041307360675295, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.01646603072894332, 'n_size': 450, 'n_batch': 113, 'duration': 104.75376605987549, 'duration_batch': 0.9270244784059778, 'duration_size': 0.23278614679972331, 'avg_pred_std': 0.1397827780404771}\n", + "Epoch 1\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train loss 
{'avg_role_model_loss': 0.014255170632645281, 'avg_role_model_std_loss': 0.24843680457298686, 'avg_role_model_mean_pred_loss': 0.000636702568035041, 'avg_role_model_g_mag_loss': 0.15559853479266167, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.014447312605981198, 'n_size': 900, 'n_batch': 225, 'duration': 406.69940519332886, 'duration_batch': 1.8075529119703504, 'duration_size': 0.4518882279925876, 'avg_pred_std': 0.22602122453765736}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Val loss {'avg_role_model_loss': 0.011084747278914115, 'avg_role_model_std_loss': 0.774564493527634, 'avg_role_model_mean_pred_loss': 0.00036442376687365014, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.011084747278914115, 'n_size': 450, 'n_batch': 113, 'duration': 105.70433187484741, 'duration_batch': 0.9354365652641364, 'duration_size': 0.23489851527743869, 'avg_pred_std': 0.13189282911609182}\n", + "Epoch 2\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train loss {'avg_role_model_loss': 0.009175471019561883, 'avg_role_model_std_loss': 0.14389958513339807, 'avg_role_model_mean_pred_loss': 0.0003093209020328993, 'avg_role_model_g_mag_loss': 0.16885624952562567, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00928754332613001, 'n_size': 900, 'n_batch': 225, 'duration': 406.6975419521332, 'duration_batch': 1.8075446308983696, 'duration_size': 0.4518861577245924, 'avg_pred_std': 0.23166270198714403}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Val loss {'avg_role_model_loss': 0.00785602382393045, 
'avg_role_model_std_loss': 1.3146459211567783, 'avg_role_model_mean_pred_loss': 0.0001525110592525784, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00785602382393045, 'n_size': 450, 'n_batch': 113, 'duration': 105.51167917251587, 'duration_batch': 0.9337316740930608, 'duration_size': 0.23447039816114637, 'avg_pred_std': 0.11703323267914861}\n", + "Epoch 3\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train loss {'avg_role_model_loss': 0.006946129192502769, 'avg_role_model_std_loss': 0.16631651740127382, 'avg_role_model_mean_pred_loss': 0.00011629207969720338, 'avg_role_model_g_mag_loss': 0.1469966934973167, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.007030756754486194, 'n_size': 900, 'n_batch': 225, 'duration': 404.7738826274872, 'duration_batch': 1.798995033899943, 'duration_size': 0.44974875847498574, 'avg_pred_std': 0.23086450531949393}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Val loss {'avg_role_model_loss': 0.006420711345431553, 'avg_role_model_std_loss': 1.1872994360646385, 'avg_role_model_mean_pred_loss': 8.72488180067034e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.006420711345431553, 'n_size': 450, 'n_batch': 113, 'duration': 103.53712010383606, 'duration_batch': 0.9162577000339475, 'duration_size': 0.2300824891196357, 'avg_pred_std': 0.1300600932704221}\n", + "Epoch 4\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train loss {'avg_role_model_loss': 0.005421485699508695, 'avg_role_model_std_loss': 0.08202415918292068, 
'avg_role_model_mean_pred_loss': 0.00010085892813725515, 'avg_role_model_g_mag_loss': 0.13527608269825578, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.005492081102662875, 'n_size': 900, 'n_batch': 225, 'duration': 403.6879127025604, 'duration_batch': 1.7941685009002686, 'duration_size': 0.44854212522506715, 'avg_pred_std': 0.23655629260775943}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Val loss {'avg_role_model_loss': 0.009256900170618995, 'avg_role_model_std_loss': 2.8233697297467697, 'avg_role_model_mean_pred_loss': 0.0006465031184101571, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.009256900170618995, 'n_size': 450, 'n_batch': 113, 'duration': 101.8208520412445, 'duration_batch': 0.9010694870906594, 'duration_size': 0.22626856009165447, 'avg_pred_std': 0.13276277751503135}\n", + "Epoch 5\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train loss {'avg_role_model_loss': 0.0043073366452492535, 'avg_role_model_std_loss': 0.19139563928810466, 'avg_role_model_mean_pred_loss': 8.40804457109845e-05, 'avg_role_model_g_mag_loss': 0.12489897313269062, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0043637448957360905, 'n_size': 900, 'n_batch': 225, 'duration': 398.688072681427, 'duration_batch': 1.7719469896952311, 'duration_size': 0.4429867474238078, 'avg_pred_std': 0.23179641341906973}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Val loss {'avg_role_model_loss': 0.0055422348428186925, 'avg_role_model_std_loss': 1.4432227016979617, 'avg_role_model_mean_pred_loss': 5.823207058893942e-05, 
'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0055422348428186925, 'n_size': 450, 'n_batch': 113, 'duration': 101.09947466850281, 'duration_batch': 0.8946856165354231, 'duration_size': 0.22466549926333956, 'avg_pred_std': 0.11363168490392674}\n", + "Epoch 6\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train loss {'avg_role_model_loss': 0.003263339866756117, 'avg_role_model_std_loss': 0.3763446019548343, 'avg_role_model_mean_pred_loss': 2.3110432702086737e-05, 'avg_role_model_g_mag_loss': 0.10734513330583771, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0033083394544084713, 'n_size': 900, 'n_batch': 225, 'duration': 398.69307565689087, 'duration_batch': 1.7719692251417372, 'duration_size': 0.4429923062854343, 'avg_pred_std': 0.24115597604735134}\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Val loss {'avg_role_model_loss': 0.00555163197664519, 'avg_role_model_std_loss': 1.1181977761929087, 'avg_role_model_mean_pred_loss': 0.00010318647899446903, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.00555163197664519, 'n_size': 450, 'n_batch': 113, 'duration': 101.25119686126709, 'duration_batch': 0.8960282908076733, 'duration_size': 0.22500265969170463, 'avg_pred_std': 0.11773095095579861}\n", + "Epoch 7\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train loss {'avg_role_model_loss': 0.0027152151927152266, 'avg_role_model_std_loss': 0.07267851859859407, 'avg_role_model_mean_pred_loss': 1.4569837801249912e-05, 'avg_role_model_g_mag_loss': 0.09808282882389095, 
'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0027543007831668043, 'n_size': 900, 'n_batch': 225, 'duration': 400.2191047668457, 'duration_batch': 1.7787515767415365, 'duration_size': 0.4446878941853841, 'avg_pred_std': 0.23774532583390182}\n", + "Time out: 3949.559079647064/3600\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Eval loss {'role_model': 'lct_gan', 'n_size': 1050, 'n_batch': 263, 'role_model_metrics': {'avg_loss': 0.0047467758956138695, 'avg_g_mag_loss': 0.016803590516231294, 'avg_g_cos_loss': 0.0, 'pred_duration': 6.370237112045288, 'grad_duration': 11.638083219528198, 'total_duration': 18.008320331573486, 'pred_std': 0.23633873462677002, 'std_loss': 0.0008537429966963828, 'mean_pred_loss': 3.869256761390716e-05, 'pred_rmse': 0.06889685243368149, 'pred_mae': 0.04724160581827164, 'pred_mape': 4645955.0, 'grad_rmse': 0.14325737953186035, 'grad_mae': 0.09063062816858292, 'grad_mape': 1.6706515550613403}, 'non_role_model_metrics': {'avg_loss': 0, 'avg_g_mag_loss': 0, 'avg_g_cos_loss': 0, 'avg_pred_duration': 0, 'avg_grad_duration': 0, 'avg_total_duration': 0, 'avg_pred_std': 0, 'avg_std_loss': 0, 'avg_mean_pred_loss': 0}, 'avg_metrics': {'avg_loss': 0.0047467758956138695, 'avg_g_mag_loss': 0.016803590516231294, 'avg_g_cos_loss': 0.0, 'avg_pred_duration': 6.370237112045288, 'avg_grad_duration': 11.638083219528198, 'avg_total_duration': 18.008320331573486, 'avg_pred_std': 0.23633873462677002, 'avg_std_loss': 0.0008537429966963828, 'avg_mean_pred_loss': 3.869256761390716e-05}, 'min_metrics': {'avg_loss': 0.0047467758956138695, 'avg_g_mag_loss': 0.016803590516231294, 'avg_g_cos_loss': 0.0, 'pred_duration': 6.370237112045288, 'grad_duration': 11.638083219528198, 'total_duration': 18.008320331573486, 'pred_std': 0.23633873462677002, 'std_loss': 0.0008537429966963828, 'mean_pred_loss': 3.869256761390716e-05, 
'pred_rmse': 0.06889685243368149, 'pred_mae': 0.04724160581827164, 'pred_mape': 4645955.0, 'grad_rmse': 0.14325737953186035, 'grad_mae': 0.09063062816858292, 'grad_mape': 1.6706515550613403}, 'model_metrics': {'lct_gan': {'avg_loss': 0.0047467758956138695, 'avg_g_mag_loss': 0.016803590516231294, 'avg_g_cos_loss': 0.0, 'pred_duration': 6.370237112045288, 'grad_duration': 11.638083219528198, 'total_duration': 18.008320331573486, 'pred_std': 0.23633873462677002, 'std_loss': 0.0008537429966963828, 'mean_pred_loss': 3.869256761390716e-05, 'pred_rmse': 0.06889685243368149, 'pred_mae': 0.04724160581827164, 'pred_mape': 4645955.0, 'grad_rmse': 0.14325737953186035, 'grad_mae': 0.09063062816858292, 'grad_mape': 1.6706515550613403}}}\n" + ] + } + ], + "source": [ + "import torch\n", + "from ml_utility_loss.loss_learning.estimator.pipeline import train, train_2\n", + "from ml_utility_loss.loss_learning.estimator.process_simple import train_epoch, eval as _eval\n", + "from ml_utility_loss.params import GradientPenaltyMode\n", + "from ml_utility_loss.util import clear_memory\n", + "import time\n", + "#torch.autograd.set_detect_anomaly(True)\n", + "\n", + "del model\n", + "clear_memory()\n", + "\n", + "#opt = params[\"Optim\"](model.parameters())\n", + "loss = train_2(\n", + " [train_set, val_set, test_set],\n", + " preprocessor=preprocessor,\n", + " #whole_model=model,\n", + " #optim=opt,\n", + " log_dir=\"logs\",\n", + " checkpoint_dir=\"checkpoints\",\n", + " verbose=True,\n", + " allow_same_prediction=allow_same_prediction,\n", + " wandb=wandb if log_wandb else None,\n", + " study_name=study_name,\n", + " **params\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "9b514a07", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T19:00:38.430630Z", + "iopub.status.busy": "2024-03-22T19:00:38.429858Z", + "iopub.status.idle": "2024-03-22T19:00:38.434341Z", + "shell.execute_reply": "2024-03-22T19:00:38.433435Z" + }, + "papermill": { + 
"duration": 0.024595, + "end_time": "2024-03-22T19:00:38.436245", + "exception": false, + "start_time": "2024-03-22T19:00:38.411650", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "model = loss[\"whole_model\"]\n", + "opt = loss[\"optim\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "331a49e1", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T19:00:38.469769Z", + "iopub.status.busy": "2024-03-22T19:00:38.469456Z", + "iopub.status.idle": "2024-03-22T19:00:38.596113Z", + "shell.execute_reply": "2024-03-22T19:00:38.595324Z" + }, + "papermill": { + "duration": 0.146222, + "end_time": "2024-03-22T19:00:38.598627", + "exception": false, + "start_time": "2024-03-22T19:00:38.452405", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import torch\n", + "from copy import deepcopy\n", + "\n", + "torch.save(deepcopy(model.state_dict()), \"model.pt\")\n", + "#torch.save(deepcopy(opt.state_dict()), \"optim.pt\")" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "123b4b17", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T19:00:38.635319Z", + "iopub.status.busy": "2024-03-22T19:00:38.634452Z", + "iopub.status.idle": "2024-03-22T19:00:38.905876Z", + "shell.execute_reply": "2024-03-22T19:00:38.904973Z" + }, + "papermill": { + "duration": 0.291967, + "end_time": "2024-03-22T19:00:38.907984", + "exception": false, + "start_time": "2024-03-22T19:00:38.616017", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "history = loss[\"history\"]\n", + "history.to_csv(\"history.csv\")\n", + "history[[\"avg_loss_train\", \"avg_loss_test\"]].plot()" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "2586ba0a", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T19:00:38.945124Z", + "iopub.status.busy": "2024-03-22T19:00:38.944791Z", + "iopub.status.idle": "2024-03-22T19:05:25.592731Z", + "shell.execute_reply": "2024-03-22T19:05:25.591881Z" + }, + "papermill": { + "duration": 286.669422, + "end_time": "2024-03-22T19:05:25.595196", + "exception": false, + "start_time": "2024-03-22T19:00:38.925774", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "\n", + "from ml_utility_loss.loss_learning.estimator.pipeline import eval\n", + "#eval_loss = loss[\"eval_loss\"]\n", + "\n", + "batch_size = params[\"batch_size_low\"] if \"batch_size_low\" in params else params[\"batch_size\"]\n", + "\n", + "eval_loss = eval(\n", + " test_set, model,\n", + " batch_size=batch_size,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "187137f6", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T19:05:25.632815Z", + "iopub.status.busy": "2024-03-22T19:05:25.631976Z", + "iopub.status.idle": "2024-03-22T19:05:25.653483Z", + "shell.execute_reply": "2024-03-22T19:05:25.652522Z" + }, + "papermill": { + "duration": 0.042389, + "end_time": "2024-03-22T19:05:25.655469", + "exception": false, + "start_time": "2024-03-22T19:05:25.613080", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
avg_g_cos_lossavg_g_mag_lossavg_lossgrad_durationgrad_maegrad_mapegrad_rmsemean_pred_losspred_durationpred_maepred_mapepred_rmsepred_stdstd_losstotal_duration
lct_gan0.00.0004490.00474711.5821970.0906311.6706520.1432570.0000396.4151070.0472424645954.50.0688970.2363390.00085417.997304
\n", + "
" + ], + "text/plain": [ + " avg_g_cos_loss avg_g_mag_loss avg_loss grad_duration grad_mae \\\n", + "lct_gan 0.0 0.000449 0.004747 11.582197 0.090631 \n", + "\n", + " grad_mape grad_rmse mean_pred_loss pred_duration pred_mae \\\n", + "lct_gan 1.670652 0.143257 0.000039 6.415107 0.047242 \n", + "\n", + " pred_mape pred_rmse pred_std std_loss total_duration \n", + "lct_gan 4645954.5 0.068897 0.236339 0.000854 17.997304 " + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "\n", + "metrics = pd.DataFrame(eval_loss[\"model_metrics\"]).T\n", + "metrics.to_csv(\"eval.csv\")\n", + "metrics" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "id": "123d305b", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T19:05:25.690857Z", + "iopub.status.busy": "2024-03-22T19:05:25.690259Z", + "iopub.status.idle": "2024-03-22T19:05:26.216854Z", + "shell.execute_reply": "2024-03-22T19:05:26.215856Z" + }, + "papermill": { + "duration": 0.546735, + "end_time": "2024-03-22T19:05:26.219100", + "exception": false, + "start_time": "2024-03-22T19:05:25.672365", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "from ml_utility_loss.util import clear_memory\n", + "clear_memory()" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "a3eecc2a", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T19:05:26.257147Z", + "iopub.status.busy": "2024-03-22T19:05:26.256380Z", + "iopub.status.idle": "2024-03-22T19:10:48.212600Z", + "shell.execute_reply": "2024-03-22T19:10:48.211612Z" + }, + "papermill": { + "duration": 321.97834, + "end_time": "2024-03-22T19:10:48.215223", + "exception": false, + "start_time": "2024-03-22T19:05:26.236883", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Caching in ../../../../treatment/_cache_aug_test/lct_gan/all 
inf False\n", + "Caching in ../../../../treatment/_cache_bs_test/lct_gan/all inf False\n", + "Caching in ../../../../treatment/_cache_synth_test/lct_gan/all inf False\n" + ] + } + ], + "source": [ + "#\"\"\"\n", + "from ml_utility_loss.loss_learning.estimator.process import pred, pred_2\n", + "from ml_utility_loss.util import stack_samples\n", + "\n", + "#samples = test_set[list(range(len(test_set)))]\n", + "#y = {m: pred(model[m], s) for m, s in samples.items()}\n", + "y = pred_2(model, test_set, batch_size=batch_size)\n", + "#\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "6ab51db8", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T19:10:48.252524Z", + "iopub.status.busy": "2024-03-22T19:10:48.252211Z", + "iopub.status.idle": "2024-03-22T19:10:48.278190Z", + "shell.execute_reply": "2024-03-22T19:10:48.277470Z" + }, + "papermill": { + "duration": 0.046837, + "end_time": "2024-03-22T19:10:48.280187", + "exception": false, + "start_time": "2024-03-22T19:10:48.233350", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [ + "import os\n", + "import pandas as pd\n", + "from ml_utility_loss.util import transpose_dict\n", + "\n", + "os.makedirs(\"pred\", exist_ok=True)\n", + "y2 = transpose_dict(y)\n", + "for k, v in y2.items():\n", + " df = pd.DataFrame(v)\n", + " df.to_csv(f\"pred/{k}.csv\")" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "d81a30f1", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T19:10:48.314828Z", + "iopub.status.busy": "2024-03-22T19:10:48.314536Z", + "iopub.status.idle": "2024-03-22T19:10:48.319916Z", + "shell.execute_reply": "2024-03-22T19:10:48.319062Z" + }, + "papermill": { + "duration": 0.025269, + "end_time": "2024-03-22T19:10:48.322225", + "exception": false, + "start_time": "2024-03-22T19:10:48.296956", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": 
[ + "{'lct_gan': 0.41919846044793607}\n" + ] + } + ], + "source": [ + "print({k: sum(v[\"pred\"])/len(v[\"pred\"]) for k, v in y.items()})" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "3b3ff322", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T19:10:48.360763Z", + "iopub.status.busy": "2024-03-22T19:10:48.359857Z", + "iopub.status.idle": "2024-03-22T19:10:48.794117Z", + "shell.execute_reply": "2024-03-22T19:10:48.793175Z" + }, + "papermill": { + "duration": 0.455983, + "end_time": "2024-03-22T19:10:48.796356", + "exception": false, + "start_time": "2024-03-22T19:10:48.340373", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from ml_utility_loss.loss_learning.visualization import plot_pred_density_2\n", + "\n", + "_ = plot_pred_density_2(y)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "e79e4b0f", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T19:10:48.834451Z", + "iopub.status.busy": "2024-03-22T19:10:48.834145Z", + "iopub.status.idle": "2024-03-22T19:10:49.238845Z", + "shell.execute_reply": "2024-03-22T19:10:49.237854Z" + }, + "papermill": { + "duration": 0.426045, + "end_time": "2024-03-22T19:10:49.240918", + "exception": false, + "start_time": "2024-03-22T19:10:48.814873", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from ml_utility_loss.loss_learning.visualization import plot_density_3\n", + "\n", + "_ = plot_density_3(y2[\"pred\"], next(iter(y2[\"y\"].values())))" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "745adde1", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T19:10:49.281017Z", + "iopub.status.busy": "2024-03-22T19:10:49.280683Z", + "iopub.status.idle": "2024-03-22T19:10:49.490180Z", + "shell.execute_reply": "2024-03-22T19:10:49.489281Z" + }, + "papermill": { + "duration": 0.232063, + "end_time": "2024-03-22T19:10:49.492171", + "exception": false, + "start_time": "2024-03-22T19:10:49.260108", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "from ml_utility_loss.loss_learning.visualization import plot_box_3\n", + "\n", + "_ = plot_box_3(y2[\"pred\"], next(iter(y2[\"y\"].values())))" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "eabe1bab", + "metadata": { + "execution": { + "iopub.execute_input": "2024-03-22T19:10:49.532570Z", + "iopub.status.busy": "2024-03-22T19:10:49.531730Z", + "iopub.status.idle": "2024-03-22T19:10:49.808277Z", + "shell.execute_reply": "2024-03-22T19:10:49.807332Z" + }, + "papermill": { + "duration": 0.2992, + "end_time": "2024-03-22T19:10:49.810292", + "exception": false, + "start_time": "2024-03-22T19:10:49.511092", + "status": "completed" + }, + "tags": [] + }, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "#\"\"\"\n", + "from ml_utility_loss.loss_learning.visualization import plot_grad, plot_grad_2, plot_grad_3\n", + "import matplotlib.pyplot as plt\n", + "\n", + "#plot_grad_2(y, model.models)\n", + "for m in model.models:\n", + " ym = y[m]\n", + " fig, ax = plt.subplots()\n", + " plot_grad_3(ym[\"error\"], ym[\"grad\"], name=f\"{m}_grad\", fig=fig, ax=ax)\n", + "#\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "54c0e9f3", + "metadata": { + "papermill": { + "duration": 0.019261, + "end_time": "2024-03-22T19:10:49.848721", + "exception": false, + "start_time": "2024-03-22T19:10:49.829460", + "status": "completed" + }, + "tags": [] + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "accelerator": "GPU", + "celltoolbar": "Tags", + "colab": { + "authorship_tag": "ABX9TyOOVfelovKP9fLGU7SvvRie", + "gpuType": "T4", + "mount_file_id": "17POSGAvge8y9DW9WGs2jLkibaRjToayg", + "provenance": [] + }, + "kaggle": { + "accelerator": "gpu", + "dataSources": [], + "dockerImageVersionId": 30648, + "isGpuEnabled": true, + "isInternetEnabled": true, + "language": "python", + "sourceType": "notebook" + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + }, + "papermill": { + "default_parameters": {}, + "duration": 8219.40624, + "end_time": "2024-03-22T19:10:52.592795", + "environment_variables": {}, + "exception": null, + "input_path": "eval/treatment/lct_gan/42/mlu-eval.ipynb", + "output_path": "eval/treatment/lct_gan/42/mlu-eval.ipynb", + "parameters": { + "allow_same_prediction": true, + "dataset": "treatment", + "dataset_name": "treatment", + 
"debug": false, + "folder": "eval", + "gp": true, + "gp_multiply": true, + "log_wandb": false, + "param_index": 0, + "path": "eval/treatment/lct_gan/42", + "path_prefix": "../../../../", + "random_seed": 42, + "single_model": "lct_gan" + }, + "start_time": "2024-03-22T16:53:53.186555", + "version": "2.5.0" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": false + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} \ No newline at end of file