File size: 241,849 Bytes
6950504 |
|
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "982e76f5",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:26:50.558870Z",
"iopub.status.busy": "2024-02-29T18:26:50.558154Z",
"iopub.status.idle": "2024-02-29T18:26:50.590106Z",
"shell.execute_reply": "2024-02-29T18:26:50.589458Z"
},
"papermill": {
"duration": 0.046189,
"end_time": "2024-02-29T18:26:50.591955",
"exception": false,
"start_time": "2024-02-29T18:26:50.545766",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"import joblib\n",
"\n",
"#joblib.parallel_backend(\"threading\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "675f0b41",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:26:50.617349Z",
"iopub.status.busy": "2024-02-29T18:26:50.616718Z",
"iopub.status.idle": "2024-02-29T18:26:50.624173Z",
"shell.execute_reply": "2024-02-29T18:26:50.623371Z"
},
"papermill": {
"duration": 0.022172,
"end_time": "2024-02-29T18:26:50.626060",
"exception": false,
"start_time": "2024-02-29T18:26:50.603888",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"'\\n%cd /kaggle/working\\n#!git clone https://github.com/R-N/ml-utility-loss\\n%cd ml-utility-loss\\n!git pull\\n#!pip install .\\n!pip install . --no-deps --force-reinstall --upgrade\\n#'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"\"\"\"\n",
"%cd /kaggle/working\n",
"#!git clone https://github.com/R-N/ml-utility-loss\n",
"%cd ml-utility-loss\n",
"!git pull\n",
"#!pip install .\n",
"!pip install . --no-deps --force-reinstall --upgrade\n",
"#\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "5ae30f5c",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:26:50.649180Z",
"iopub.status.busy": "2024-02-29T18:26:50.648911Z",
"iopub.status.idle": "2024-02-29T18:26:50.653522Z",
"shell.execute_reply": "2024-02-29T18:26:50.652094Z"
},
"papermill": {
"duration": 0.019003,
"end_time": "2024-02-29T18:26:50.655938",
"exception": false,
"start_time": "2024-02-29T18:26:50.636935",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt\n",
"plt.rcParams['figure.figsize'] = [3,3]"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9f42c810",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:26:50.679562Z",
"iopub.status.busy": "2024-02-29T18:26:50.679302Z",
"iopub.status.idle": "2024-02-29T18:26:50.683027Z",
"shell.execute_reply": "2024-02-29T18:26:50.682249Z"
},
"executionInfo": {
"elapsed": 678,
"status": "ok",
"timestamp": 1696841022168,
"user": {
"displayName": "Rizqi Nur",
"userId": "09644007964068789560"
},
"user_tz": -420
},
"id": "ns5hFcVL2yvs",
"papermill": {
"duration": 0.01756,
"end_time": "2024-02-29T18:26:50.684861",
"exception": false,
"start_time": "2024-02-29T18:26:50.667301",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"datasets = [\n",
" \"insurance\",\n",
" \"treatment\",\n",
" \"contraceptive\"\n",
"]\n",
"\n",
"study_dir = \"./\""
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "85d0c8ce",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:26:50.707664Z",
"iopub.status.busy": "2024-02-29T18:26:50.707390Z",
"iopub.status.idle": "2024-02-29T18:26:50.712716Z",
"shell.execute_reply": "2024-02-29T18:26:50.711943Z"
},
"papermill": {
"duration": 0.018823,
"end_time": "2024-02-29T18:26:50.714583",
"exception": false,
"start_time": "2024-02-29T18:26:50.695760",
"status": "completed"
},
"tags": [
"parameters"
]
},
"outputs": [],
"source": [
"#Parameters\n",
"import os\n",
"\n",
"path_prefix = \"../../../../\"\n",
"\n",
"dataset_dir = os.path.join(path_prefix, \"ml-utility-loss/datasets\")\n",
"dataset_name = \"treatment\"\n",
"model_name=\"ml_utility_2\"\n",
"models = [\"tvae\", \"realtabformer\", \"lct_gan\", \"tab_ddpm_concat\"]\n",
"single_model = \"lct_gan\"\n",
"random_seed = 42\n",
"gp = True\n",
"gp_multiply = True\n",
"folder = \"eval\"\n",
"debug = False\n",
"path = None\n",
"param_index = 0\n",
"allow_same_prediction = False"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "1800468a",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:26:50.739163Z",
"iopub.status.busy": "2024-02-29T18:26:50.738835Z",
"iopub.status.idle": "2024-02-29T18:26:50.744203Z",
"shell.execute_reply": "2024-02-29T18:26:50.743458Z"
},
"papermill": {
"duration": 0.019989,
"end_time": "2024-02-29T18:26:50.746255",
"exception": false,
"start_time": "2024-02-29T18:26:50.726266",
"status": "completed"
},
"tags": [
"injected-parameters"
]
},
"outputs": [],
"source": [
"# Parameters\n",
"dataset = \"insurance\"\n",
"dataset_name = \"insurance\"\n",
"single_model = \"tab_ddpm_concat\"\n",
"gp = False\n",
"gp_multiply = False\n",
"random_seed = 4\n",
"debug = False\n",
"folder = \"eval\"\n",
"path_prefix = \"../../../../\"\n",
"path = \"eval/insurance/tab_ddpm_concat/4\"\n",
"param_index = 2\n",
"allow_same_prediction = True\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bd7c02d6",
"metadata": {
"papermill": {
"duration": 0.010925,
"end_time": "2024-02-29T18:26:50.768318",
"exception": false,
"start_time": "2024-02-29T18:26:50.757393",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 7,
"id": "5f45b1d0",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:26:50.791282Z",
"iopub.status.busy": "2024-02-29T18:26:50.791016Z",
"iopub.status.idle": "2024-02-29T18:26:50.799559Z",
"shell.execute_reply": "2024-02-29T18:26:50.798806Z"
},
"executionInfo": {
"elapsed": 7,
"status": "ok",
"timestamp": 1696841022169,
"user": {
"displayName": "Rizqi Nur",
"userId": "09644007964068789560"
},
"user_tz": -420
},
"id": "UdvXYv3c3LXy",
"papermill": {
"duration": 0.022153,
"end_time": "2024-02-29T18:26:50.801373",
"exception": false,
"start_time": "2024-02-29T18:26:50.779220",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/kaggle/working\n",
"/kaggle/working/eval/insurance/tab_ddpm_concat/4\n"
]
}
],
"source": [
"from pathlib import Path\n",
"import os\n",
"\n",
"%cd /kaggle/working/\n",
"\n",
"if path is None:\n",
" path = os.path.join(folder, dataset_name, single_model, random_seed)\n",
"Path(path).mkdir(parents=True, exist_ok=True)\n",
"\n",
"%cd {path}"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "f85bf540",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:26:50.825158Z",
"iopub.status.busy": "2024-02-29T18:26:50.824897Z",
"iopub.status.idle": "2024-02-29T18:26:52.939311Z",
"shell.execute_reply": "2024-02-29T18:26:52.938380Z"
},
"papermill": {
"duration": 2.128694,
"end_time": "2024-02-29T18:26:52.941418",
"exception": false,
"start_time": "2024-02-29T18:26:50.812724",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Set seed to <function seed at 0x7a58ca31a320>\n"
]
}
],
"source": [
"from ml_utility_loss.util import seed\n",
"if single_model:\n",
" model_name=f\"{model_name}_{single_model}\"\n",
"if random_seed is not None:\n",
" seed(random_seed)\n",
" print(\"Set seed to\", seed)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "8489feae",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:26:52.968131Z",
"iopub.status.busy": "2024-02-29T18:26:52.967214Z",
"iopub.status.idle": "2024-02-29T18:26:52.979057Z",
"shell.execute_reply": "2024-02-29T18:26:52.978209Z"
},
"papermill": {
"duration": 0.027034,
"end_time": "2024-02-29T18:26:52.981011",
"exception": false,
"start_time": "2024-02-29T18:26:52.953977",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import json\n",
"import os\n",
"\n",
"df = pd.read_csv(os.path.join(dataset_dir, f\"{dataset_name}.csv\"))\n",
"with open(os.path.join(dataset_dir, f\"{dataset_name}.json\")) as f:\n",
" info = json.load(f)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "debcc684",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:26:53.005698Z",
"iopub.status.busy": "2024-02-29T18:26:53.005020Z",
"iopub.status.idle": "2024-02-29T18:26:53.011733Z",
"shell.execute_reply": "2024-02-29T18:26:53.010953Z"
},
"executionInfo": {
"elapsed": 6,
"status": "ok",
"timestamp": 1696841022169,
"user": {
"displayName": "Rizqi Nur",
"userId": "09644007964068789560"
},
"user_tz": -420
},
"id": "Vrl2QkoV3o_8",
"papermill": {
"duration": 0.021252,
"end_time": "2024-02-29T18:26:53.013739",
"exception": false,
"start_time": "2024-02-29T18:26:52.992487",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"task = info[\"task\"]\n",
"target = info[\"target\"]\n",
"cat_features = info[\"cat_features\"]\n",
"mixed_features = info[\"mixed_features\"]\n",
"longtail_features = info[\"longtail_features\"]\n",
"integer_features = info[\"integer_features\"]\n",
"\n",
"test = df.sample(frac=0.2, random_state=42)\n",
"train = df[~df.index.isin(test.index)]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "7538184a",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:26:53.037402Z",
"iopub.status.busy": "2024-02-29T18:26:53.037158Z",
"iopub.status.idle": "2024-02-29T18:26:53.136345Z",
"shell.execute_reply": "2024-02-29T18:26:53.135661Z"
},
"executionInfo": {
"elapsed": 6,
"status": "ok",
"timestamp": 1696841022169,
"user": {
"displayName": "Rizqi Nur",
"userId": "09644007964068789560"
},
"user_tz": -420
},
"id": "TilUuFk9vqMb",
"papermill": {
"duration": 0.113075,
"end_time": "2024-02-29T18:26:53.138309",
"exception": false,
"start_time": "2024-02-29T18:26:53.025234",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"import ml_utility_loss.synthesizers.tab_ddpm.params as TAB_DDPM_PARAMS\n",
"import ml_utility_loss.synthesizers.lct_gan.params as LCT_GAN_PARAMS\n",
"import ml_utility_loss.synthesizers.realtabformer.params as RTF_PARAMS\n",
"from ml_utility_loss.synthesizers.realtabformer.params.default import GPT2_PARAMS, REALTABFORMER_PARAMS\n",
"from ml_utility_loss.util import filter_dict_2, filter_dict\n",
"\n",
"tab_ddpm_params = getattr(TAB_DDPM_PARAMS, dataset_name).BEST\n",
"lct_gan_params = getattr(LCT_GAN_PARAMS, dataset_name).BEST\n",
"lct_ae_params = filter_dict_2(lct_gan_params, LCT_GAN_PARAMS.default.AE_PARAMS)\n",
"rtf_params = getattr(RTF_PARAMS, dataset_name).BEST\n",
"rtf_params = filter_dict(rtf_params, REALTABFORMER_PARAMS)\n",
"\n",
"lct_ae_embedding_size=lct_gan_params[\"embedding_size\"]\n",
"tab_ddpm_normalization=\"quantile\"\n",
"tab_ddpm_cat_encoding=tab_ddpm_params[\"cat_encoding\"]\n",
"#tab_ddpm_cat_encoding=\"one-hot\"\n",
"tab_ddpm_y_policy=\"default\"\n",
"tab_ddpm_is_y_cond=True"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "cca61838",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:26:53.163865Z",
"iopub.status.busy": "2024-02-29T18:26:53.163568Z",
"iopub.status.idle": "2024-02-29T18:26:57.713583Z",
"shell.execute_reply": "2024-02-29T18:26:57.712809Z"
},
"executionInfo": {
"elapsed": 3113,
"status": "ok",
"timestamp": 1696841025277,
"user": {
"displayName": "Rizqi Nur",
"userId": "09644007964068789560"
},
"user_tz": -420
},
"id": "7Abt8nStvr9Z",
"papermill": {
"duration": 4.565534,
"end_time": "2024-02-29T18:26:57.715908",
"exception": false,
"start_time": "2024-02-29T18:26:53.150374",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2024-02-29 18:26:55.362648: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n",
"2024-02-29 18:26:55.362701: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n",
"2024-02-29 18:26:55.364282: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n"
]
}
],
"source": [
"from ml_utility_loss.loss_learning.estimator.pipeline import load_lct_ae\n",
"\n",
"lct_ae = load_lct_ae(\n",
" dataset_name=dataset_name,\n",
" model_dir=os.path.join(path_prefix, \"ml-utility-loss/models\"),\n",
" model_name=\"lct_ae\",\n",
" df_name=\"df\",\n",
")\n",
"lct_ae = None"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "6f83b7b6",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:26:57.741084Z",
"iopub.status.busy": "2024-02-29T18:26:57.740440Z",
"iopub.status.idle": "2024-02-29T18:26:57.747224Z",
"shell.execute_reply": "2024-02-29T18:26:57.746357Z"
},
"papermill": {
"duration": 0.021021,
"end_time": "2024-02-29T18:26:57.749183",
"exception": false,
"start_time": "2024-02-29T18:26:57.728162",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"from ml_utility_loss.loss_learning.estimator.pipeline import load_rtf_embed\n",
"\n",
"rtf_embed = load_rtf_embed(\n",
" dataset_name=dataset_name,\n",
" model_dir=os.path.join(path_prefix, \"ml-utility-loss/models\"),\n",
" model_name=\"realtabformer\",\n",
" df_name=\"df\",\n",
" ckpt_type=\"best-disc-model\"\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "0026de74",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:26:57.774557Z",
"iopub.status.busy": "2024-02-29T18:26:57.774285Z",
"iopub.status.idle": "2024-02-29T18:27:05.763239Z",
"shell.execute_reply": "2024-02-29T18:27:05.762279Z"
},
"executionInfo": {
"elapsed": 20137,
"status": "ok",
"timestamp": 1696841045408,
"user": {
"displayName": "Rizqi Nur",
"userId": "09644007964068789560"
},
"user_tz": -420
},
"id": "tbaguWxAvtPi",
"papermill": {
"duration": 8.004862,
"end_time": "2024-02-29T18:27:05.766044",
"exception": false,
"start_time": "2024-02-29T18:26:57.761182",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n",
" warnings.warn(\n",
"/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n",
" warnings.warn(\n",
"/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:119: ConvergenceWarning: Number of distinct clusters (6) found smaller than n_clusters (10). Possibly due to duplicate points in X.\n",
" .fit(X)\n",
"/opt/conda/lib/python3.10/site-packages/sklearn/mixture/_base.py:274: ConvergenceWarning: Initialization 1 did not converge. Try different init parameters, or increase max_iter, tol or check for degenerate data.\n",
" warnings.warn(\n",
"100%|██████████| 1/1 [00:00<00:00, 2.95it/s]\n",
"/opt/conda/lib/python3.10/site-packages/sklearn/preprocessing/_encoders.py:868: FutureWarning: `sparse` was renamed to `sparse_output` in version 1.2 and will be removed in 1.4. `sparse_output` is ignored unless you leave `sparse` to its default value.\n",
" warnings.warn(\n"
]
}
],
"source": [
"from ml_utility_loss.loss_learning.estimator.preprocessing import DataPreprocessor\n",
"\n",
"preprocessor = DataPreprocessor(\n",
" task,\n",
" target=target,\n",
" cat_features=cat_features,\n",
" mixed_features=mixed_features,\n",
" longtail_features=longtail_features,\n",
" integer_features=integer_features,\n",
" lct_ae_embedding_size=lct_ae_embedding_size,\n",
" lct_ae_params=lct_ae_params,\n",
" lct_ae=lct_ae,\n",
" tab_ddpm_normalization=tab_ddpm_normalization,\n",
" tab_ddpm_cat_encoding=tab_ddpm_cat_encoding,\n",
" tab_ddpm_y_policy=tab_ddpm_y_policy,\n",
" tab_ddpm_is_y_cond=tab_ddpm_is_y_cond,\n",
" realtabformer_embedding=rtf_embed,\n",
" realtabformer_params=rtf_params,\n",
")\n",
"preprocessor.fit(df)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "a9c9b110",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"execution": {
"iopub.execute_input": "2024-02-29T18:27:05.793951Z",
"iopub.status.busy": "2024-02-29T18:27:05.793616Z",
"iopub.status.idle": "2024-02-29T18:27:05.800096Z",
"shell.execute_reply": "2024-02-29T18:27:05.799321Z"
},
"executionInfo": {
"elapsed": 13,
"status": "ok",
"timestamp": 1696841045411,
"user": {
"displayName": "Rizqi Nur",
"userId": "09644007964068789560"
},
"user_tz": -420
},
"id": "OxUH_GBEv2qK",
"outputId": "76464c90-3baf-4bdc-a955-6f4fddc16b9c",
"papermill": {
"duration": 0.022395,
"end_time": "2024-02-29T18:27:05.801957",
"exception": false,
"start_time": "2024-02-29T18:27:05.779562",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"{'tvae': 36,\n",
" 'realtabformer': (19, 551, Embedding(551, 800), True),\n",
" 'lct_gan': 29,\n",
" 'tab_ddpm_concat': 12}"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"preprocessor.adapter_sizes"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "3cb9ed90",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:27:05.827106Z",
"iopub.status.busy": "2024-02-29T18:27:05.826850Z",
"iopub.status.idle": "2024-02-29T18:27:05.831397Z",
"shell.execute_reply": "2024-02-29T18:27:05.830675Z"
},
"papermill": {
"duration": 0.019112,
"end_time": "2024-02-29T18:27:05.833194",
"exception": false,
"start_time": "2024-02-29T18:27:05.814082",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"from ml_utility_loss.loss_learning.estimator.pipeline import load_dataset_3_factory\n",
"\n",
"datasetsn = load_dataset_3_factory(\n",
" dataset_dir=os.path.join(path_prefix, \"ml-utility-loss/\"),\n",
" dataset_name=dataset_name,\n",
" preprocessor=preprocessor,\n",
" cache_dir=path_prefix,\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "ad1eb833",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:27:05.858118Z",
"iopub.status.busy": "2024-02-29T18:27:05.857860Z",
"iopub.status.idle": "2024-02-29T18:27:05.888185Z",
"shell.execute_reply": "2024-02-29T18:27:05.887355Z"
},
"papermill": {
"duration": 0.044834,
"end_time": "2024-02-29T18:27:05.889997",
"exception": false,
"start_time": "2024-02-29T18:27:05.845163",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Caching in ../../../../insurance/_cache_test/tab_ddpm_concat/all inf False\n"
]
}
],
"source": [
"from ml_utility_loss.loss_learning.estimator.pipeline import load_dataset\n",
"\n",
"test_set = load_dataset(\n",
" dataset_dir=os.path.join(path_prefix, \"ml-utility-loss/\", \"datasets_5\", dataset_name),\n",
" preprocessor=preprocessor,\n",
" cache_dir=os.path.join(path_prefix, dataset_name, \"_cache_test\"),\n",
" start=200,\n",
" #stop=600,\n",
" val=False,\n",
" ratio=0,\n",
" drop_first_column=True,\n",
" model=single_model,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "14ff8b40",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:27:05.917077Z",
"iopub.status.busy": "2024-02-29T18:27:05.916535Z",
"iopub.status.idle": "2024-02-29T18:27:06.229267Z",
"shell.execute_reply": "2024-02-29T18:27:06.228385Z"
},
"executionInfo": {
"elapsed": 588,
"status": "ok",
"timestamp": 1696841049215,
"user": {
"displayName": "Rizqi Nur",
"userId": "09644007964068789560"
},
"user_tz": -420
},
"id": "NgahtU1q9uLO",
"papermill": {
"duration": 0.328598,
"end_time": "2024-02-29T18:27:06.231292",
"exception": false,
"start_time": "2024-02-29T18:27:05.902694",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"{'Body': 'twin_encoder',\n",
" 'loss_balancer_meta': True,\n",
" 'loss_balancer_log': False,\n",
" 'loss_balancer_lbtw': False,\n",
" 'pma_skip_small': False,\n",
" 'isab_skip_small': False,\n",
" 'layer_norm': False,\n",
" 'pma_layer_norm': False,\n",
" 'attn_residual': True,\n",
" 'tf_n_layers_dec': False,\n",
" 'tf_isab_rank': 0,\n",
" 'tf_layer_norm': False,\n",
" 'tf_pma_start': -1,\n",
" 'head_n_seeds': 0,\n",
" 'tf_pma_low': 16,\n",
" 'dropout': 0,\n",
" 'combine_mode': 'diff_left',\n",
" 'tf_isab_mode': 'separate',\n",
" 'grad_loss_fn': <function torch.nn.functional.l1_loss(input: torch.Tensor, target: torch.Tensor, size_average: Optional[bool] = None, reduce: Optional[bool] = None, reduction: str = 'mean') -> torch.Tensor>,\n",
" 'single_model': True,\n",
" 'bias': True,\n",
" 'bias_final': True,\n",
" 'pma_ffn_mode': 'shared',\n",
" 'patience': 10,\n",
" 'inds_init_mode': 'fixnorm',\n",
" 'grad_clip': 0.77,\n",
" 'head_final_mul': 'identity',\n",
" 'gradient_penalty_mode': {'gradient_penalty': False,\n",
" 'calc_grad_m': False,\n",
" 'avg_non_role_model_m': False,\n",
" 'inverse_avg_non_role_model_m': False},\n",
" 'synth_data': 2,\n",
" 'dataset_size': 2048,\n",
" 'batch_size': 8,\n",
" 'epochs': 100,\n",
" 'n_warmup_steps': 100,\n",
" 'Optim': torch_optimizer.diffgrad.DiffGrad,\n",
" 'loss_balancer_beta': 0.75,\n",
" 'loss_balancer_r': 0.95,\n",
" 'fixed_role_model': 'tab_ddpm_concat',\n",
" 'd_model': 256,\n",
" 'attn_activation': torch.nn.modules.activation.LeakyReLU,\n",
" 'tf_d_inner': 512,\n",
" 'tf_n_layers_enc': 4,\n",
" 'tf_n_head': 64,\n",
" 'tf_activation': torch.nn.modules.activation.ReLU6,\n",
" 'tf_activation_final': ml_utility_loss.activations.LeakyHardsigmoid,\n",
" 'ada_d_hid': 1024,\n",
" 'ada_n_layers': 7,\n",
" 'ada_activation': torch.nn.modules.activation.ReLU,\n",
" 'ada_activation_final': torch.nn.modules.activation.Softsign,\n",
" 'head_d_hid': 128,\n",
" 'head_n_layers': 9,\n",
" 'head_n_head': 64,\n",
" 'head_activation': torch.nn.modules.activation.RReLU,\n",
" 'head_activation_final': torch.nn.modules.activation.Softsign,\n",
" 'models': ['tab_ddpm_concat'],\n",
" 'max_seconds': 3600,\n",
" 'tf_lora': False,\n",
" 'tf_num_inds': 32,\n",
" 'ada_n_seeds': 0,\n",
" 'gradient_penalty_kwargs': {'mag_loss': True,\n",
" 'mse_mag': False,\n",
" 'mag_corr': False,\n",
" 'seq_mag': False,\n",
" 'cos_loss': False,\n",
" 'mag_corr_kwargs': {'only_sign': False},\n",
" 'cos_loss_kwargs': {'only_sign': True, 'cos_matrix': False},\n",
" 'mse_mag_kwargs': {'target': 0.1, 'multiply': False}}}"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import ml_utility_loss.loss_learning.estimator.params2 as PARAMS\n",
"from ml_utility_loss.tuning import map_parameters\n",
"from ml_utility_loss.loss_learning.estimator.params.default import update_param_space, update_param_space_2\n",
"import wandb\n",
"\n",
"#\"\"\"\n",
"param_space = {\n",
" **getattr(PARAMS, dataset_name).PARAM_SPACE,\n",
"}\n",
"params = {\n",
" **getattr(PARAMS, dataset_name).BESTS[param_index],\n",
"}\n",
"if gp:\n",
" params[\"gradient_penalty_mode\"] = \"ALL\"\n",
" params[\"mse_mag\"] = True\n",
" if gp_multiply:\n",
" params[\"mse_mag_multiply\"] = True\n",
" params[\"mse_mag_target\"] = 1.0\n",
" else:\n",
" params[\"mse_mag_multiply\"] = False\n",
" params[\"mse_mag_target\"] = 0.1\n",
"else:\n",
" params[\"gradient_penalty_mode\"] = \"NONE\"\n",
" params[\"mse_mag\"] = False\n",
"params[\"single_model\"] = False\n",
"if models:\n",
" params[\"models\"] = models\n",
"if single_model:\n",
" params[\"fixed_role_model\"] = single_model\n",
" params[\"single_model\"] = True\n",
" params[\"models\"] = [single_model]\n",
"if params[\"fixed_role_model\"] == \"realtabformer\" and dataset_name == \"treatment\":\n",
" params[\"batch_size\"] = 2\n",
"params[\"max_seconds\"] = 3600\n",
"params[\"patience\"] = 10\n",
"params[\"epochs\"] = 100\n",
"if debug:\n",
" params[\"epochs\"] = 2\n",
"with open(\"params.json\", \"w\") as f:\n",
" json.dump(params, f)\n",
"params = map_parameters(params, param_space=param_space)\n",
"params"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "a48bd9e9",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:27:06.257693Z",
"iopub.status.busy": "2024-02-29T18:27:06.257389Z",
"iopub.status.idle": "2024-02-29T18:27:06.330588Z",
"shell.execute_reply": "2024-02-29T18:27:06.329733Z"
},
"papermill": {
"duration": 0.088632,
"end_time": "2024-02-29T18:27:06.332688",
"exception": false,
"start_time": "2024-02-29T18:27:06.244056",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"load_dataset_3_factory 2\n",
"Caching in ../../../../insurance/_cache/tab_ddpm_concat/all inf False\n",
"Splitting without random!\n",
"Split with reverse index!\n",
"../../../../ml-utility-loss/datasets_2/insurance [80, 20]\n",
"Caching in ../../../../insurance/_cache4/tab_ddpm_concat/all inf False\n",
"Splitting without random!\n",
"Split with reverse index!\n",
"../../../../ml-utility-loss/datasets_4/insurance [80, 20]\n",
"Caching in ../../../../insurance/_cache5/tab_ddpm_concat/all inf False\n",
"Splitting without random!\n",
"Split with reverse index!\n",
"../../../../ml-utility-loss/datasets_5/insurance [160, 40]\n",
"[320, 80]\n",
"[320, 80]\n"
]
}
],
"source": [
"train_set, val_set = datasetsn(model=params[\"fixed_role_model\"], synth_data=params[\"synth_data\"])"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "2fcb1418",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 1000
},
"execution": {
"iopub.execute_input": "2024-02-29T18:27:06.361058Z",
"iopub.status.busy": "2024-02-29T18:27:06.360443Z",
"iopub.status.idle": "2024-02-29T18:27:06.783264Z",
"shell.execute_reply": "2024-02-29T18:27:06.782352Z"
},
"executionInfo": {
"elapsed": 396850,
"status": "error",
"timestamp": 1696841446059,
"user": {
"displayName": "Rizqi Nur",
"userId": "09644007964068789560"
},
"user_tz": -420
},
"id": "_bt1MQc5kpSk",
"outputId": "01c1d3e5-ac64-461d-835a-b76f4a66e6d6",
"papermill": {
"duration": 0.439025,
"end_time": "2024-02-29T18:27:06.785332",
"exception": false,
"start_time": "2024-02-29T18:27:06.346307",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Creating model of type <class 'ml_utility_loss.loss_learning.estimator.model.models.TwinEncoder'>\n",
"[*] Embedding False True\n",
"['tab_ddpm_concat'] 1\n"
]
}
],
"source": [
"from ml_utility_loss.loss_learning.estimator.model.pipeline import remove_non_model_params\n",
"from ml_utility_loss.loss_learning.estimator.pipeline import create_model\n",
"from ml_utility_loss.util import filter_dict, clear_memory\n",
"\n",
"clear_memory()\n",
"\n",
"params2 = remove_non_model_params(params)\n",
"adapters = filter_dict(preprocessor.adapter_sizes, params[\"models\"])\n",
"\n",
"model = create_model(\n",
" adapters=adapters,\n",
" #Body=\"twin_encoder\",\n",
" **params2,\n",
")\n",
"#cf.apply_weight_standardization(model, n_last_layers_ignore=0)\n",
"print(model.models, len(model.adapters))"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "938f94fc",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:27:06.814758Z",
"iopub.status.busy": "2024-02-29T18:27:06.813947Z",
"iopub.status.idle": "2024-02-29T18:27:06.818524Z",
"shell.execute_reply": "2024-02-29T18:27:06.817759Z"
},
"papermill": {
"duration": 0.021262,
"end_time": "2024-02-29T18:27:06.820462",
"exception": false,
"start_time": "2024-02-29T18:27:06.799200",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"study_name=f\"{model_name}_{dataset_name}\""
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "12fb613e",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:27:06.848255Z",
"iopub.status.busy": "2024-02-29T18:27:06.847981Z",
"iopub.status.idle": "2024-02-29T18:27:06.854747Z",
"shell.execute_reply": "2024-02-29T18:27:06.853941Z"
},
"papermill": {
"duration": 0.02231,
"end_time": "2024-02-29T18:27:06.856576",
"exception": false,
"start_time": "2024-02-29T18:27:06.834266",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"9613953"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def count_parameters(model):\n",
" return sum(p.numel() for p in model.parameters() if p.requires_grad)\n",
"\n",
"count_parameters(model)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "bd386e57",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:27:06.882937Z",
"iopub.status.busy": "2024-02-29T18:27:06.882671Z",
"iopub.status.idle": "2024-02-29T18:27:06.968606Z",
"shell.execute_reply": "2024-02-29T18:27:06.967758Z"
},
"papermill": {
"duration": 0.101194,
"end_time": "2024-02-29T18:27:06.970459",
"exception": false,
"start_time": "2024-02-29T18:27:06.869265",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"========================================================================================================================\n",
"Layer (type:depth-idx) Output Shape Param #\n",
"========================================================================================================================\n",
"MLUtilitySingle [2, 1071, 12] --\n",
"├─Adapter: 1-1 [2, 1071, 12] --\n",
"│ └─Sequential: 2-1 [2, 1071, 256] --\n",
"│ │ └─FeedForward: 3-1 [2, 1071, 1024] --\n",
"│ │ │ └─Linear: 4-1 [2, 1071, 1024] 13,312\n",
"│ │ │ └─ReLU: 4-2 [2, 1071, 1024] --\n",
"│ │ └─FeedForward: 3-2 [2, 1071, 1024] --\n",
"│ │ │ └─Linear: 4-3 [2, 1071, 1024] 1,049,600\n",
"│ │ │ └─ReLU: 4-4 [2, 1071, 1024] --\n",
"│ │ └─FeedForward: 3-3 [2, 1071, 1024] --\n",
"│ │ │ └─Linear: 4-5 [2, 1071, 1024] 1,049,600\n",
"│ │ │ └─ReLU: 4-6 [2, 1071, 1024] --\n",
"│ │ └─FeedForward: 3-4 [2, 1071, 1024] --\n",
"│ │ │ └─Linear: 4-7 [2, 1071, 1024] 1,049,600\n",
"│ │ │ └─ReLU: 4-8 [2, 1071, 1024] --\n",
"│ │ └─FeedForward: 3-5 [2, 1071, 1024] --\n",
"│ │ │ └─Linear: 4-9 [2, 1071, 1024] 1,049,600\n",
"│ │ │ └─ReLU: 4-10 [2, 1071, 1024] --\n",
"│ │ └─FeedForward: 3-6 [2, 1071, 1024] --\n",
"│ │ │ └─Linear: 4-11 [2, 1071, 1024] 1,049,600\n",
"│ │ │ └─ReLU: 4-12 [2, 1071, 1024] --\n",
"│ │ └─FeedForward: 3-7 [2, 1071, 256] --\n",
"│ │ │ └─Linear: 4-13 [2, 1071, 256] 262,400\n",
"│ │ │ └─Softsign: 4-14 [2, 1071, 256] --\n",
"├─Adapter: 1-2 [2, 267, 12] (recursive)\n",
"│ └─Sequential: 2-2 [2, 267, 256] (recursive)\n",
"│ │ └─FeedForward: 3-8 [2, 267, 1024] (recursive)\n",
"│ │ │ └─Linear: 4-15 [2, 267, 1024] (recursive)\n",
"│ │ │ └─ReLU: 4-16 [2, 267, 1024] --\n",
"│ │ └─FeedForward: 3-9 [2, 267, 1024] (recursive)\n",
"│ │ │ └─Linear: 4-17 [2, 267, 1024] (recursive)\n",
"│ │ │ └─ReLU: 4-18 [2, 267, 1024] --\n",
"│ │ └─FeedForward: 3-10 [2, 267, 1024] (recursive)\n",
"│ │ │ └─Linear: 4-19 [2, 267, 1024] (recursive)\n",
"│ │ │ └─ReLU: 4-20 [2, 267, 1024] --\n",
"│ │ └─FeedForward: 3-11 [2, 267, 1024] (recursive)\n",
"│ │ │ └─Linear: 4-21 [2, 267, 1024] (recursive)\n",
"│ │ │ └─ReLU: 4-22 [2, 267, 1024] --\n",
"│ │ └─FeedForward: 3-12 [2, 267, 1024] (recursive)\n",
"│ │ │ └─Linear: 4-23 [2, 267, 1024] (recursive)\n",
"│ │ │ └─ReLU: 4-24 [2, 267, 1024] --\n",
"│ │ └─FeedForward: 3-13 [2, 267, 1024] (recursive)\n",
"│ │ │ └─Linear: 4-25 [2, 267, 1024] (recursive)\n",
"│ │ │ └─ReLU: 4-26 [2, 267, 1024] --\n",
"│ │ └─FeedForward: 3-14 [2, 267, 256] (recursive)\n",
"│ │ │ └─Linear: 4-27 [2, 267, 256] (recursive)\n",
"│ │ │ └─Softsign: 4-28 [2, 267, 256] --\n",
"├─TwinEncoder: 1-3 [2, 4096] --\n",
"│ └─Encoder: 2-3 [2, 16, 256] --\n",
"│ │ └─ModuleList: 3-16 -- (recursive)\n",
"│ │ │ └─EncoderLayer: 4-29 [2, 1071, 256] --\n",
"│ │ │ │ └─SimpleInducedSetAttention: 5-1 [2, 1071, 256] --\n",
"│ │ │ │ │ └─TensorInductionPoint: 6-1 [2, 32, 256] 8,192\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-2 [2, 32, 256] --\n",
"│ │ │ │ │ │ └─Linear: 7-1 [2, 32, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-2 [2, 1071, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-3 [2, 1071, 256] 65,536\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-4 [2, 64, 32, 4] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-1 [2, 64, 32, 1071] --\n",
"│ │ │ │ │ │ └─Linear: 7-5 [2, 32, 256] 65,792\n",
"│ │ │ │ │ │ └─LeakyReLU: 7-6 [2, 32, 256] --\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-3 [2, 1071, 256] --\n",
"│ │ │ │ │ │ └─Linear: 7-7 [2, 1071, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-8 [2, 32, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-9 [2, 32, 256] 65,536\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-10 [2, 64, 1071, 4] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-2 [2, 64, 1071, 32] --\n",
"│ │ │ │ │ │ └─Linear: 7-11 [2, 1071, 256] 65,792\n",
"│ │ │ │ │ │ └─LeakyReLU: 7-12 [2, 1071, 256] --\n",
"│ │ │ │ └─DoubleFeedForward: 5-2 [2, 1071, 256] --\n",
"│ │ │ │ │ └─Linear: 6-4 [2, 1071, 512] 131,584\n",
"│ │ │ │ │ └─ReLU6: 6-5 [2, 1071, 512] --\n",
"│ │ │ │ │ └─Linear: 6-6 [2, 1071, 256] 131,328\n",
"│ │ │ └─EncoderLayer: 4-30 [2, 1071, 256] --\n",
"│ │ │ │ └─SimpleInducedSetAttention: 5-3 [2, 1071, 256] --\n",
"│ │ │ │ │ └─TensorInductionPoint: 6-7 [2, 32, 256] 8,192\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-8 [2, 32, 256] --\n",
"│ │ │ │ │ │ └─Linear: 7-13 [2, 32, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-14 [2, 1071, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-15 [2, 1071, 256] 65,536\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-16 [2, 64, 32, 4] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-3 [2, 64, 32, 1071] --\n",
"│ │ │ │ │ │ └─Linear: 7-17 [2, 32, 256] 65,792\n",
"│ │ │ │ │ │ └─LeakyReLU: 7-18 [2, 32, 256] --\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-9 [2, 1071, 256] --\n",
"│ │ │ │ │ │ └─Linear: 7-19 [2, 1071, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-20 [2, 32, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-21 [2, 32, 256] 65,536\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-22 [2, 64, 1071, 4] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-4 [2, 64, 1071, 32] --\n",
"│ │ │ │ │ │ └─Linear: 7-23 [2, 1071, 256] 65,792\n",
"│ │ │ │ │ │ └─LeakyReLU: 7-24 [2, 1071, 256] --\n",
"│ │ │ │ └─DoubleFeedForward: 5-4 [2, 1071, 256] --\n",
"│ │ │ │ │ └─Linear: 6-10 [2, 1071, 512] 131,584\n",
"│ │ │ │ │ └─ReLU6: 6-11 [2, 1071, 512] --\n",
"│ │ │ │ │ └─Linear: 6-12 [2, 1071, 256] 131,328\n",
"│ │ │ └─EncoderLayer: 4-31 [2, 1071, 256] --\n",
"│ │ │ │ └─SimpleInducedSetAttention: 5-5 [2, 1071, 256] --\n",
"│ │ │ │ │ └─TensorInductionPoint: 6-13 [2, 32, 256] 8,192\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-14 [2, 32, 256] --\n",
"│ │ │ │ │ │ └─Linear: 7-25 [2, 32, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-26 [2, 1071, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-27 [2, 1071, 256] 65,536\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-28 [2, 64, 32, 4] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-5 [2, 64, 32, 1071] --\n",
"│ │ │ │ │ │ └─Linear: 7-29 [2, 32, 256] 65,792\n",
"│ │ │ │ │ │ └─LeakyReLU: 7-30 [2, 32, 256] --\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-15 [2, 1071, 256] --\n",
"│ │ │ │ │ │ └─Linear: 7-31 [2, 1071, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-32 [2, 32, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-33 [2, 32, 256] 65,536\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-34 [2, 64, 1071, 4] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-6 [2, 64, 1071, 32] --\n",
"│ │ │ │ │ │ └─Linear: 7-35 [2, 1071, 256] 65,792\n",
"│ │ │ │ │ │ └─LeakyReLU: 7-36 [2, 1071, 256] --\n",
"│ │ │ │ └─DoubleFeedForward: 5-6 [2, 1071, 256] --\n",
"│ │ │ │ │ └─Linear: 6-16 [2, 1071, 512] 131,584\n",
"│ │ │ │ │ └─ReLU6: 6-17 [2, 1071, 512] --\n",
"│ │ │ │ │ └─Linear: 6-18 [2, 1071, 256] 131,328\n",
"│ │ │ └─EncoderLayer: 4-32 [2, 16, 256] --\n",
"│ │ │ │ └─SimpleInducedSetAttention: 5-7 [2, 1071, 256] --\n",
"│ │ │ │ │ └─TensorInductionPoint: 6-19 [2, 32, 256] 8,192\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-20 [2, 32, 256] --\n",
"│ │ │ │ │ │ └─Linear: 7-37 [2, 32, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-38 [2, 1071, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-39 [2, 1071, 256] 65,536\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-40 [2, 64, 32, 4] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-7 [2, 64, 32, 1071] --\n",
"│ │ │ │ │ │ └─Linear: 7-41 [2, 32, 256] 65,792\n",
"│ │ │ │ │ │ └─LeakyReLU: 7-42 [2, 32, 256] --\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-21 [2, 1071, 256] --\n",
"│ │ │ │ │ │ └─Linear: 7-43 [2, 1071, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-44 [2, 32, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-45 [2, 32, 256] 65,536\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-46 [2, 64, 1071, 4] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-8 [2, 64, 1071, 32] --\n",
"│ │ │ │ │ │ └─Linear: 7-47 [2, 1071, 256] 65,792\n",
"│ │ │ │ │ │ └─LeakyReLU: 7-48 [2, 1071, 256] --\n",
"│ │ │ │ └─DoubleFeedForward: 5-8 [2, 1071, 256] --\n",
"│ │ │ │ │ └─Linear: 6-22 [2, 1071, 512] 131,584\n",
"│ │ │ │ │ └─LeakyHardsigmoid: 6-23 [2, 1071, 512] --\n",
"│ │ │ │ │ └─Linear: 6-24 [2, 1071, 256] 131,328\n",
"│ │ │ │ └─PoolingByMultiheadAttention: 5-9 [2, 16, 256] --\n",
"│ │ │ │ │ └─TensorInductionPoint: 6-25 [2, 16, 256] 4,096\n",
"│ │ │ │ │ └─SimpleMultiHeadAttention: 6-26 [2, 16, 256] --\n",
"│ │ │ │ │ │ └─Linear: 7-49 [2, 16, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-50 [2, 1071, 256] 65,536\n",
"│ │ │ │ │ │ └─Linear: 7-51 [2, 1071, 256] 65,536\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-52 [2, 64, 16, 4] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-9 [2, 64, 16, 1071] --\n",
"│ │ │ │ │ │ └─Linear: 7-53 [2, 16, 256] 65,792\n",
"│ │ │ │ │ │ └─LeakyReLU: 7-54 [2, 16, 256] --\n",
"│ │ │ │ └─DoubleFeedForward: 5-10 [2, 16, 256] (recursive)\n",
"│ │ │ │ │ └─Linear: 6-27 [2, 16, 512] (recursive)\n",
"│ │ │ │ │ └─LeakyHardsigmoid: 6-28 [2, 16, 512] --\n",
"│ │ │ │ │ └─Linear: 6-29 [2, 16, 256] (recursive)\n",
"│ └─Encoder: 2-4 [2, 16, 256] (recursive)\n",
"│ │ └─ModuleList: 3-16 -- (recursive)\n",
"│ │ │ └─EncoderLayer: 4-33 [2, 267, 256] (recursive)\n",
"│ │ │ │ └─SimpleInducedSetAttention: 5-11 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ └─TensorInductionPoint: 6-30 [2, 32, 256] (recursive)\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-31 [2, 32, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-55 [2, 32, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-56 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-57 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-58 [2, 64, 32, 4] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-10 [2, 64, 32, 267] --\n",
"│ │ │ │ │ │ └─Linear: 7-59 [2, 32, 256] (recursive)\n",
"│ │ │ │ │ │ └─LeakyReLU: 7-60 [2, 32, 256] --\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-32 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-61 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-62 [2, 32, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-63 [2, 32, 256] (recursive)\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-64 [2, 64, 267, 4] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-11 [2, 64, 267, 32] --\n",
"│ │ │ │ │ │ └─Linear: 7-65 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ │ └─LeakyReLU: 7-66 [2, 267, 256] --\n",
"│ │ │ │ └─DoubleFeedForward: 5-12 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ └─Linear: 6-33 [2, 267, 512] (recursive)\n",
"│ │ │ │ │ └─ReLU6: 6-34 [2, 267, 512] --\n",
"│ │ │ │ │ └─Linear: 6-35 [2, 267, 256] (recursive)\n",
"│ │ │ └─EncoderLayer: 4-34 [2, 267, 256] (recursive)\n",
"│ │ │ │ └─SimpleInducedSetAttention: 5-13 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ └─TensorInductionPoint: 6-36 [2, 32, 256] (recursive)\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-37 [2, 32, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-67 [2, 32, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-68 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-69 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-70 [2, 64, 32, 4] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-12 [2, 64, 32, 267] --\n",
"│ │ │ │ │ │ └─Linear: 7-71 [2, 32, 256] (recursive)\n",
"│ │ │ │ │ │ └─LeakyReLU: 7-72 [2, 32, 256] --\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-38 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-73 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-74 [2, 32, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-75 [2, 32, 256] (recursive)\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-76 [2, 64, 267, 4] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-13 [2, 64, 267, 32] --\n",
"│ │ │ │ │ │ └─Linear: 7-77 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ │ └─LeakyReLU: 7-78 [2, 267, 256] --\n",
"│ │ │ │ └─DoubleFeedForward: 5-14 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ └─Linear: 6-39 [2, 267, 512] (recursive)\n",
"│ │ │ │ │ └─ReLU6: 6-40 [2, 267, 512] --\n",
"│ │ │ │ │ └─Linear: 6-41 [2, 267, 256] (recursive)\n",
"│ │ │ └─EncoderLayer: 4-35 [2, 267, 256] (recursive)\n",
"│ │ │ │ └─SimpleInducedSetAttention: 5-15 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ └─TensorInductionPoint: 6-42 [2, 32, 256] (recursive)\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-43 [2, 32, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-79 [2, 32, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-80 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-81 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-82 [2, 64, 32, 4] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-14 [2, 64, 32, 267] --\n",
"│ │ │ │ │ │ └─Linear: 7-83 [2, 32, 256] (recursive)\n",
"│ │ │ │ │ │ └─LeakyReLU: 7-84 [2, 32, 256] --\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-44 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-85 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-86 [2, 32, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-87 [2, 32, 256] (recursive)\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-88 [2, 64, 267, 4] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-15 [2, 64, 267, 32] --\n",
"│ │ │ │ │ │ └─Linear: 7-89 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ │ └─LeakyReLU: 7-90 [2, 267, 256] --\n",
"│ │ │ │ └─DoubleFeedForward: 5-16 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ └─Linear: 6-45 [2, 267, 512] (recursive)\n",
"│ │ │ │ │ └─ReLU6: 6-46 [2, 267, 512] --\n",
"│ │ │ │ │ └─Linear: 6-47 [2, 267, 256] (recursive)\n",
"│ │ │ └─EncoderLayer: 4-36 [2, 16, 256] (recursive)\n",
"│ │ │ │ └─SimpleInducedSetAttention: 5-17 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ └─TensorInductionPoint: 6-48 [2, 32, 256] (recursive)\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-49 [2, 32, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-91 [2, 32, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-92 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-93 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-94 [2, 64, 32, 4] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-16 [2, 64, 32, 267] --\n",
"│ │ │ │ │ │ └─Linear: 7-95 [2, 32, 256] (recursive)\n",
"│ │ │ │ │ │ └─LeakyReLU: 7-96 [2, 32, 256] --\n",
"│ │ │ │ │ └─MultiHeadAttention: 6-50 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-97 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-98 [2, 32, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-99 [2, 32, 256] (recursive)\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-100 [2, 64, 267, 4] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-17 [2, 64, 267, 32] --\n",
"│ │ │ │ │ │ └─Linear: 7-101 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ │ └─LeakyReLU: 7-102 [2, 267, 256] --\n",
"│ │ │ │ └─DoubleFeedForward: 5-18 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ └─Linear: 6-51 [2, 267, 512] (recursive)\n",
"│ │ │ │ │ └─LeakyHardsigmoid: 6-52 [2, 267, 512] --\n",
"│ │ │ │ │ └─Linear: 6-53 [2, 267, 256] (recursive)\n",
"│ │ │ │ └─PoolingByMultiheadAttention: 5-19 [2, 16, 256] (recursive)\n",
"│ │ │ │ │ └─TensorInductionPoint: 6-54 [2, 16, 256] (recursive)\n",
"│ │ │ │ │ └─SimpleMultiHeadAttention: 6-55 [2, 16, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-103 [2, 16, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-104 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ │ └─Linear: 7-105 [2, 267, 256] (recursive)\n",
"│ │ │ │ │ │ └─ScaledDotProductAttention: 7-106 [2, 64, 16, 4] --\n",
"│ │ │ │ │ │ │ └─Softmax: 8-18 [2, 64, 16, 267] --\n",
"│ │ │ │ │ │ └─Linear: 7-107 [2, 16, 256] (recursive)\n",
"│ │ │ │ │ │ └─LeakyReLU: 7-108 [2, 16, 256] --\n",
"│ │ │ │ └─DoubleFeedForward: 5-20 [2, 16, 256] (recursive)\n",
"│ │ │ │ │ └─Linear: 6-56 [2, 16, 512] (recursive)\n",
"│ │ │ │ │ └─LeakyHardsigmoid: 6-57 [2, 16, 512] --\n",
"│ │ │ │ │ └─Linear: 6-58 [2, 16, 256] (recursive)\n",
"├─Head: 1-4 [2] --\n",
"│ └─Sequential: 2-5 [2, 1] --\n",
"│ │ └─FeedForward: 3-17 [2, 128] --\n",
"│ │ │ └─Linear: 4-37 [2, 128] 524,416\n",
"│ │ │ └─RReLU: 4-38 [2, 128] --\n",
"│ │ └─FeedForward: 3-18 [2, 128] --\n",
"│ │ │ └─Linear: 4-39 [2, 128] 16,512\n",
"│ │ │ └─RReLU: 4-40 [2, 128] --\n",
"│ │ └─FeedForward: 3-19 [2, 128] --\n",
"│ │ │ └─Linear: 4-41 [2, 128] 16,512\n",
"│ │ │ └─RReLU: 4-42 [2, 128] --\n",
"│ │ └─FeedForward: 3-20 [2, 128] --\n",
"│ │ │ └─Linear: 4-43 [2, 128] 16,512\n",
"│ │ │ └─RReLU: 4-44 [2, 128] --\n",
"│ │ └─FeedForward: 3-21 [2, 128] --\n",
"│ │ │ └─Linear: 4-45 [2, 128] 16,512\n",
"│ │ │ └─RReLU: 4-46 [2, 128] --\n",
"│ │ └─FeedForward: 3-22 [2, 128] --\n",
"│ │ │ └─Linear: 4-47 [2, 128] 16,512\n",
"│ │ │ └─RReLU: 4-48 [2, 128] --\n",
"│ │ └─FeedForward: 3-23 [2, 128] --\n",
"│ │ │ └─Linear: 4-49 [2, 128] 16,512\n",
"│ │ │ └─RReLU: 4-50 [2, 128] --\n",
"│ │ └─FeedForward: 3-24 [2, 128] --\n",
"│ │ │ └─Linear: 4-51 [2, 128] 16,512\n",
"│ │ │ └─RReLU: 4-52 [2, 128] --\n",
"│ │ └─FeedForward: 3-25 [2, 1] --\n",
"│ │ │ └─Linear: 4-53 [2, 1] 129\n",
"│ │ │ └─Softsign: 4-54 [2, 1] --\n",
"========================================================================================================================\n",
"Total params: 9,613,953\n",
"Trainable params: 9,613,953\n",
"Non-trainable params: 0\n",
"Total mult-adds (M): 38.08\n",
"========================================================================================================================\n",
"Input size (MB): 0.13\n",
"Forward/backward pass size (MB): 307.47\n",
"Params size (MB): 38.46\n",
"Estimated Total Size (MB): 346.05\n",
"========================================================================================================================"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from torchinfo import summary\n",
"\n",
"# Print a layer-by-layer summary of the fixed role model, using input shapes\n",
"# taken from the first training sample with a leading batch dimension of 2.\n",
"role_model = params[\"fixed_role_model\"]\n",
"s = train_set[0][role_model]\n",
"summary(\n",
"    model[role_model],\n",
"    input_size=tuple((2, *s[i].shape) for i in (0, 1)),\n",
"    depth=9,  # 8 max\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "0f42c4d1",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:27:07.000684Z",
"iopub.status.busy": "2024-02-29T18:27:07.000395Z",
"iopub.status.idle": "2024-02-29T18:42:55.818307Z",
"shell.execute_reply": "2024-02-29T18:42:55.817254Z"
},
"papermill": {
"duration": 948.852675,
"end_time": "2024-02-29T18:42:55.837260",
"exception": false,
"start_time": "2024-02-29T18:27:06.984585",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: Tracking run with wandb version 0.16.3\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: W&B syncing is set to \u001b[1m`offline`\u001b[0m in this directory. \n",
"\u001b[34m\u001b[1mwandb\u001b[0m: Run \u001b[1m`wandb online`\u001b[0m or set \u001b[1mWANDB_MODE=online\u001b[0m to enable cloud syncing.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"g_loss_mul 0.1\n",
"Epoch 0\n",
"Train loss {'avg_role_model_loss': 0.0265493107464863, 'avg_role_model_std_loss': 9.705848431753656, 'avg_role_model_mean_pred_loss': 0.0019637997826472465, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.0265493107464863, 'n_size': 320, 'n_batch': 40, 'duration': 39.08715486526489, 'duration_batch': 0.9771788716316223, 'duration_size': 0.12214735895395279, 'avg_pred_std': 0.04609664692543447}\n",
"Val loss {'avg_role_model_loss': 0.012864274116873275, 'avg_role_model_std_loss': 8.93672634124523, 'avg_role_model_mean_pred_loss': 3.463389237516879e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.012864274116873275, 'n_size': 80, 'n_batch': 10, 'duration': 8.234524965286255, 'duration_batch': 0.8234524965286255, 'duration_size': 0.10293156206607819, 'avg_pred_std': 0.023089123656973243}\n",
"Epoch 1\n",
"Train loss {'avg_role_model_loss': 0.013430703204357996, 'avg_role_model_std_loss': 10.238072396071818, 'avg_role_model_mean_pred_loss': 0.0001760885078965657, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.013430703204357996, 'n_size': 320, 'n_batch': 40, 'duration': 38.923088788986206, 'duration_batch': 0.9730772197246551, 'duration_size': 0.12163465246558189, 'avg_pred_std': 0.027457697270438074}\n",
"Val loss {'avg_role_model_loss': 0.01386686596670188, 'avg_role_model_std_loss': 9.424022936335371, 'avg_role_model_mean_pred_loss': 5.71949209714262e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.01386686596670188, 'n_size': 80, 'n_batch': 10, 'duration': 8.236119270324707, 'duration_batch': 0.8236119270324707, 'duration_size': 0.10295149087905883, 'avg_pred_std': 0.019944945629686118}\n",
"Epoch 2\n",
"Train loss {'avg_role_model_loss': 0.013098158335196786, 'avg_role_model_std_loss': 6.953670260656827, 'avg_role_model_mean_pred_loss': 7.627181049958409e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.013098158335196786, 'n_size': 320, 'n_batch': 40, 'duration': 38.896809816360474, 'duration_batch': 0.9724202454090118, 'duration_size': 0.12155253067612648, 'avg_pred_std': 0.03701225146651268}\n",
"Val loss {'avg_role_model_loss': 0.011231413613131735, 'avg_role_model_std_loss': 4.642900250397725, 'avg_role_model_mean_pred_loss': 1.232088975626766e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.011231413613131735, 'n_size': 80, 'n_batch': 10, 'duration': 8.272239923477173, 'duration_batch': 0.8272239923477173, 'duration_size': 0.10340299904346466, 'avg_pred_std': 0.031016640178859235}\n",
"Epoch 3\n",
"Train loss {'avg_role_model_loss': 0.013012661421089432, 'avg_role_model_std_loss': 6.77741541211999, 'avg_role_model_mean_pred_loss': 0.00014677781123761946, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.013012661421089432, 'n_size': 320, 'n_batch': 40, 'duration': 39.03108096122742, 'duration_batch': 0.9757770240306854, 'duration_size': 0.12197212800383568, 'avg_pred_std': 0.040795679786242545}\n",
"Val loss {'avg_role_model_loss': 0.010680149483960122, 'avg_role_model_std_loss': 5.439762359634369, 'avg_role_model_mean_pred_loss': 8.51207419643174e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.010680149483960122, 'n_size': 80, 'n_batch': 10, 'duration': 8.236795425415039, 'duration_batch': 0.8236795425415039, 'duration_size': 0.10295994281768799, 'avg_pred_std': 0.02782872337847948}\n",
"Epoch 4\n",
"Train loss {'avg_role_model_loss': 0.012592662169481628, 'avg_role_model_std_loss': 6.8064604322151805, 'avg_role_model_mean_pred_loss': 0.00012719917820476213, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.012592662169481628, 'n_size': 320, 'n_batch': 40, 'duration': 38.966336727142334, 'duration_batch': 0.9741584181785583, 'duration_size': 0.12176980227231979, 'avg_pred_std': 0.03671876427251845}\n",
"Val loss {'avg_role_model_loss': 0.012881963208201341, 'avg_role_model_std_loss': 16.157494982505522, 'avg_role_model_mean_pred_loss': 0.00010115250418607502, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.012881963208201341, 'n_size': 80, 'n_batch': 10, 'duration': 8.331452369689941, 'duration_batch': 0.8331452369689941, 'duration_size': 0.10414315462112426, 'avg_pred_std': 0.012491705431602895}\n",
"Epoch 5\n",
"Train loss {'avg_role_model_loss': 0.013670370759791694, 'avg_role_model_std_loss': 10.748200260194086, 'avg_role_model_mean_pred_loss': 0.0001568969438597634, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.013670370759791694, 'n_size': 320, 'n_batch': 40, 'duration': 38.94208788871765, 'duration_batch': 0.9735521972179413, 'duration_size': 0.12169402465224266, 'avg_pred_std': 0.029897483938839287}\n",
"Val loss {'avg_role_model_loss': 0.014085652580251917, 'avg_role_model_std_loss': 22.363185199221174, 'avg_role_model_mean_pred_loss': 0.00020219407759825003, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.014085652580251917, 'n_size': 80, 'n_batch': 10, 'duration': 8.2787184715271, 'duration_batch': 0.82787184715271, 'duration_size': 0.10348398089408875, 'avg_pred_std': 0.009641142934560776}\n",
"Epoch 6\n",
"Train loss {'avg_role_model_loss': 0.014017040852922946, 'avg_role_model_std_loss': 10.649183725507465, 'avg_role_model_mean_pred_loss': 0.00013577813718335108, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.014017040852922946, 'n_size': 320, 'n_batch': 40, 'duration': 38.94879508018494, 'duration_batch': 0.9737198770046234, 'duration_size': 0.12171498462557792, 'avg_pred_std': 0.028363983915187418}\n",
"Val loss {'avg_role_model_loss': 0.01068424858385697, 'avg_role_model_std_loss': 3.8434145080467714, 'avg_role_model_mean_pred_loss': 1.0552424407705985e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.01068424858385697, 'n_size': 80, 'n_batch': 10, 'duration': 8.305310726165771, 'duration_batch': 0.8305310726165771, 'duration_size': 0.10381638407707214, 'avg_pred_std': 0.03533868733793497}\n",
"Epoch 7\n",
"Train loss {'avg_role_model_loss': 0.011766438081394881, 'avg_role_model_std_loss': 8.660977102358947, 'avg_role_model_mean_pred_loss': 8.090821406305792e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.011766438081394881, 'n_size': 320, 'n_batch': 40, 'duration': 38.78416681289673, 'duration_batch': 0.9696041703224182, 'duration_size': 0.12120052129030227, 'avg_pred_std': 0.04158601735252887}\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Val loss {'avg_role_model_loss': 0.012133054883452132, 'avg_role_model_std_loss': 20.211999930033198, 'avg_role_model_mean_pred_loss': 2.2262640635517526e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.012133054883452132, 'n_size': 80, 'n_batch': 10, 'duration': 8.369733810424805, 'duration_batch': 0.8369733810424804, 'duration_size': 0.10462167263031005, 'avg_pred_std': 0.010681234044022858}\n",
"Epoch 8\n",
"Train loss {'avg_role_model_loss': 0.012191647826693953, 'avg_role_model_std_loss': 7.005204355998285, 'avg_role_model_mean_pred_loss': 9.821474643096905e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.012191647826693953, 'n_size': 320, 'n_batch': 40, 'duration': 38.88823890686035, 'duration_batch': 0.9722059726715088, 'duration_size': 0.1215257465839386, 'avg_pred_std': 0.03872000898700208}\n",
"Val loss {'avg_role_model_loss': 0.014966235030442476, 'avg_role_model_std_loss': 9.767283525761012, 'avg_role_model_mean_pred_loss': 0.0001517352883070089, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.014966235030442476, 'n_size': 80, 'n_batch': 10, 'duration': 8.22826075553894, 'duration_batch': 0.8228260755538941, 'duration_size': 0.10285325944423676, 'avg_pred_std': 0.01799462023191154}\n",
"Epoch 9\n",
"Train loss {'avg_role_model_loss': 0.012526353562134319, 'avg_role_model_std_loss': 6.590273188782885, 'avg_role_model_mean_pred_loss': 7.7691583878714e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.012526353562134319, 'n_size': 320, 'n_batch': 40, 'duration': 38.93899869918823, 'duration_batch': 0.9734749674797059, 'duration_size': 0.12168437093496323, 'avg_pred_std': 0.03674360387958586}\n",
"Val loss {'avg_role_model_loss': 0.012331876624375581, 'avg_role_model_std_loss': 18.443907407086634, 'avg_role_model_mean_pred_loss': 6.805989072731223e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.012331876624375581, 'n_size': 80, 'n_batch': 10, 'duration': 8.421772003173828, 'duration_batch': 0.8421772003173829, 'duration_size': 0.10527215003967286, 'avg_pred_std': 0.01039172657765448}\n",
"Epoch 10\n",
"Train loss {'avg_role_model_loss': 0.012064280622871593, 'avg_role_model_std_loss': 9.317451603279006, 'avg_role_model_mean_pred_loss': 3.690295125249321e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.012064280622871593, 'n_size': 320, 'n_batch': 40, 'duration': 39.005112171173096, 'duration_batch': 0.9751278042793274, 'duration_size': 0.12189097553491593, 'avg_pred_std': 0.0359303968725726}\n",
"Val loss {'avg_role_model_loss': 0.01261272220290266, 'avg_role_model_std_loss': 10.194672084533522, 'avg_role_model_mean_pred_loss': 5.446935015456234e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.01261272220290266, 'n_size': 80, 'n_batch': 10, 'duration': 8.333169937133789, 'duration_batch': 0.8333169937133789, 'duration_size': 0.10416462421417236, 'avg_pred_std': 0.01722581619396806}\n",
"Epoch 11\n",
"Train loss {'avg_role_model_loss': 0.012482693148194812, 'avg_role_model_std_loss': 8.178162423045615, 'avg_role_model_mean_pred_loss': 9.780007754767173e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.012482693148194812, 'n_size': 320, 'n_batch': 40, 'duration': 38.96896147727966, 'duration_batch': 0.9742240369319916, 'duration_size': 0.12177800461649894, 'avg_pred_std': 0.03824995262548327}\n",
"Val loss {'avg_role_model_loss': 0.012514100689440966, 'avg_role_model_std_loss': 19.314230701327325, 'avg_role_model_mean_pred_loss': 7.543949816977147e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.012514100689440966, 'n_size': 80, 'n_batch': 10, 'duration': 8.239241361618042, 'duration_batch': 0.8239241361618042, 'duration_size': 0.10299051702022552, 'avg_pred_std': 0.009454242698848248}\n",
"Epoch 12\n",
"Train loss {'avg_role_model_loss': 0.01332451379566919, 'avg_role_model_std_loss': 10.310542043212262, 'avg_role_model_mean_pred_loss': 0.0003665929893701819, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.01332451379566919, 'n_size': 320, 'n_batch': 40, 'duration': 39.00809144973755, 'duration_batch': 0.9752022862434387, 'duration_size': 0.12190028578042984, 'avg_pred_std': 0.027350465022027492}\n",
"Val loss {'avg_role_model_loss': 0.010987071882118471, 'avg_role_model_std_loss': 4.729085849918556, 'avg_role_model_mean_pred_loss': 8.189743033426566e-06, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.010987071882118471, 'n_size': 80, 'n_batch': 10, 'duration': 8.261511325836182, 'duration_batch': 0.8261511325836182, 'duration_size': 0.10326889157295227, 'avg_pred_std': 0.03069485481828451}\n",
"Epoch 13\n",
"Train loss {'avg_role_model_loss': 0.013592794616124592, 'avg_role_model_std_loss': 7.457387926033698, 'avg_role_model_mean_pred_loss': 0.000220215535729551, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.013592794616124592, 'n_size': 320, 'n_batch': 40, 'duration': 38.93746519088745, 'duration_batch': 0.9734366297721863, 'duration_size': 0.12167957872152328, 'avg_pred_std': 0.03546805907972157}\n",
"Val loss {'avg_role_model_loss': 0.011548876191955059, 'avg_role_model_std_loss': 6.165951245542237, 'avg_role_model_mean_pred_loss': 1.5450504935188292e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.011548876191955059, 'n_size': 80, 'n_batch': 10, 'duration': 8.323935985565186, 'duration_batch': 0.8323935985565185, 'duration_size': 0.10404919981956481, 'avg_pred_std': 0.026838560402393342}\n",
"Epoch 14\n",
"Train loss {'avg_role_model_loss': 0.013447031378746033, 'avg_role_model_std_loss': 8.19890535405798, 'avg_role_model_mean_pred_loss': 0.00016373289685844838, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.013447031378746033, 'n_size': 320, 'n_batch': 40, 'duration': 38.8496150970459, 'duration_batch': 0.9712403774261474, 'duration_size': 0.12140504717826843, 'avg_pred_std': 0.029747568373568355}\n",
"Val loss {'avg_role_model_loss': 0.011828925088047981, 'avg_role_model_std_loss': 5.351523938098455, 'avg_role_model_mean_pred_loss': 2.9337766557091526e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.011828925088047981, 'n_size': 80, 'n_batch': 10, 'duration': 8.288572311401367, 'duration_batch': 0.8288572311401368, 'duration_size': 0.1036071538925171, 'avg_pred_std': 0.0307698548771441}\n",
"Epoch 15\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train loss {'avg_role_model_loss': 0.01384369531297125, 'avg_role_model_std_loss': 8.665561918970889, 'avg_role_model_mean_pred_loss': 0.00016956335028766033, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.01384369531297125, 'n_size': 320, 'n_batch': 40, 'duration': 38.970547676086426, 'duration_batch': 0.9742636919021607, 'duration_size': 0.12178296148777008, 'avg_pred_std': 0.03315324831055477}\n",
"Val loss {'avg_role_model_loss': 0.01150583740673028, 'avg_role_model_std_loss': 6.560287872780464, 'avg_role_model_mean_pred_loss': 1.5260961676233363e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.01150583740673028, 'n_size': 80, 'n_batch': 10, 'duration': 8.322679042816162, 'duration_batch': 0.8322679042816162, 'duration_size': 0.10403348803520203, 'avg_pred_std': 0.0259027692489326}\n",
"Epoch 16\n",
"Train loss {'avg_role_model_loss': 0.012172109389211982, 'avg_role_model_std_loss': 7.0008499470219245, 'avg_role_model_mean_pred_loss': 7.735751830111326e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.012172109389211982, 'n_size': 320, 'n_batch': 40, 'duration': 39.07010316848755, 'duration_batch': 0.9767525792121887, 'duration_size': 0.12209407240152359, 'avg_pred_std': 0.03525363316293806}\n",
"Val loss {'avg_role_model_loss': 0.012191956081369425, 'avg_role_model_std_loss': 7.897130101547532, 'avg_role_model_mean_pred_loss': 2.3637969795231585e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.012191956081369425, 'n_size': 80, 'n_batch': 10, 'duration': 8.288463592529297, 'duration_batch': 0.8288463592529297, 'duration_size': 0.10360579490661621, 'avg_pred_std': 0.022672764584422113}\n",
"Epoch 17\n",
"Train loss {'avg_role_model_loss': 0.012383807837613859, 'avg_role_model_std_loss': 4.774037581340053, 'avg_role_model_mean_pred_loss': 0.00011500121783720729, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.012383807837613859, 'n_size': 320, 'n_batch': 40, 'duration': 38.94200682640076, 'duration_batch': 0.9735501706600189, 'duration_size': 0.12169377133250237, 'avg_pred_std': 0.04339534998871386}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: \n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Val loss {'avg_role_model_loss': 0.012275835702894256, 'avg_role_model_std_loss': 9.627914267603774, 'avg_role_model_mean_pred_loss': 4.899254280417153e-05, 'avg_role_model_g_mag_loss': 0.0, 'avg_role_model_g_cos_loss': 0.0, 'avg_non_role_model_g_mag_loss': 0.0, 'avg_non_role_model_g_cos_loss': 0.0, 'avg_non_role_model_embed_loss': 0.0, 'avg_loss': 0.012275835702894256, 'n_size': 80, 'n_batch': 10, 'duration': 8.323761701583862, 'duration_batch': 0.8323761701583863, 'duration_size': 0.10404702126979828, 'avg_pred_std': 0.018597377510741354}\n",
"Stopped False\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[34m\u001b[1mwandb\u001b[0m: \n",
"\u001b[34m\u001b[1mwandb\u001b[0m: Run history:\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_loss_test ▅▆▂▁▅▇▁▃█▄▄▄▂▂▃▂▃▄\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_loss_train █▂▂▂▁▂▂▁▁▁▁▁▂▂▂▂▁▁\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_embed_loss_test ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_embed_loss_train ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_g_cos_loss_test ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_g_cos_loss_train ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_g_mag_loss_test ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_g_mag_loss_train ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_pred_std_test ▅▄▇▆▂▁█▁▃▁▃▁▇▆▇▅▅▃\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_pred_std_train █▁▅▆▄▂▁▆▅▅▄▅▁▄▂▃▄▇\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_g_cos_loss_test ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_g_cos_loss_train ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_g_mag_loss_test ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_g_mag_loss_train ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_loss_test ▅▆▂▁▅▇▁▃█▄▄▄▂▂▃▂▃▄\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_loss_train █▂▂▂▁▂▂▁▁▁▁▁▂▂▂▂▁▁\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_mean_pred_loss_test ▂▃▁▁▄█▁▂▆▃▃▃▁▁▂▁▂▂\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_mean_pred_loss_train █▂▁▁▁▁▁▁▁▁▁▁▂▂▁▁▁▁\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_std_loss_test ▃▃▁▂▆█▁▇▃▇▃▇▁▂▂▂▃▃\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_std_loss_train ▇▇▄▃▃██▆▄▃▆▅▇▄▅▆▄▁\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: duration_batch_test ▁▁▃▁▅▃▄▆▁█▅▁▂▄▃▄▃▄\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: duration_batch_train █▄▄▇▅▅▅▁▃▅▆▅▆▅▃▅█▅\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: duration_size_test ▁▁▃▁▅▃▄▆▁█▅▁▂▄▃▄▃▄\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: duration_size_train █▄▄▇▅▅▅▁▃▅▆▅▆▅▃▅█▅\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: duration_test ▁▁▃▁▅▃▄▆▁█▅▁▂▄▃▄▃▄\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: duration_train █▄▄▇▅▅▅▁▃▅▆▅▆▅▃▅█▅\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: n_batch_test ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: n_batch_train ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: n_size_test ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: n_size_train ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: \n",
"\u001b[34m\u001b[1mwandb\u001b[0m: Run summary:\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_loss_test 0.01228\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_loss_train 0.01238\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_embed_loss_test 0.0\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_embed_loss_train 0.0\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_g_cos_loss_test 0.0\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_g_cos_loss_train 0.0\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_g_mag_loss_test 0.0\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_non_role_model_g_mag_loss_train 0.0\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_pred_std_test 0.0186\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_pred_std_train 0.0434\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_g_cos_loss_test 0.0\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_g_cos_loss_train 0.0\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_g_mag_loss_test 0.0\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_g_mag_loss_train 0.0\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_loss_test 0.01228\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_loss_train 0.01238\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_mean_pred_loss_test 5e-05\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_mean_pred_loss_train 0.00012\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_std_loss_test 9.62791\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: avg_role_model_std_loss_train 4.77404\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: duration_batch_test 0.83238\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: duration_batch_train 0.97355\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: duration_size_test 0.10405\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: duration_size_train 0.12169\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: duration_test 8.32376\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: duration_train 38.94201\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: n_batch_test 10\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: n_batch_train 40\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: n_size_test 80\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: n_size_train 320\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: \n",
"\u001b[34m\u001b[1mwandb\u001b[0m: You can sync this run to the cloud by running:\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[1mwandb sync /kaggle/working/eval/insurance/tab_ddpm_concat/4/wandb/offline-run-20240229_182708-etiddam5\u001b[0m\n",
"\u001b[34m\u001b[1mwandb\u001b[0m: Find logs at: \u001b[35m\u001b[1m./wandb/offline-run-20240229_182708-etiddam5/logs\u001b[0m\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Eval loss {'role_model': 'tab_ddpm_concat', 'n_size': 399, 'n_batch': 50, 'role_model_metrics': {'avg_loss': 0.01993635216760531, 'avg_g_mag_loss': 0.6526490935406888, 'avg_g_cos_loss': 4.4486220484027385e-08, 'pred_duration': 0.8737220764160156, 'grad_duration': 0.5640714168548584, 'total_duration': 1.437793493270874, 'pred_std': 0.053181204944849014, 'std_loss': 0.7849577069282532, 'mean_pred_loss': 1.8548063962953165e-05, 'pred_rmse': 0.141196146607399, 'pred_mae': 0.0972040668129921, 'pred_mape': 0.7692168354988098, 'grad_rmse': 0.28230687975883484, 'grad_mae': 0.19419056177139282, 'grad_mape': 0.9970712065696716}, 'non_role_model_metrics': {'avg_loss': 0, 'avg_g_mag_loss': 0, 'avg_g_cos_loss': 0, 'avg_pred_duration': 0, 'avg_grad_duration': 0, 'avg_total_duration': 0, 'avg_pred_std': 0, 'avg_std_loss': 0, 'avg_mean_pred_loss': 0}, 'avg_metrics': {'avg_loss': 0.01993635216760531, 'avg_g_mag_loss': 0.6526490935406888, 'avg_g_cos_loss': 4.4486220484027385e-08, 'avg_pred_duration': 0.8737220764160156, 'avg_grad_duration': 0.5640714168548584, 'avg_total_duration': 1.437793493270874, 'avg_pred_std': 0.053181204944849014, 'avg_std_loss': 0.7849577069282532, 'avg_mean_pred_loss': 1.8548063962953165e-05}, 'min_metrics': {'avg_loss': 0.01993635216760531, 'avg_g_mag_loss': 0.6526490935406888, 'avg_g_cos_loss': 4.4486220484027385e-08, 'pred_duration': 0.8737220764160156, 'grad_duration': 0.5640714168548584, 'total_duration': 1.437793493270874, 'pred_std': 0.053181204944849014, 'std_loss': 0.7849577069282532, 'mean_pred_loss': 1.8548063962953165e-05, 'pred_rmse': 0.141196146607399, 'pred_mae': 0.0972040668129921, 'pred_mape': 0.7692168354988098, 'grad_rmse': 0.28230687975883484, 'grad_mae': 0.19419056177139282, 'grad_mape': 0.9970712065696716}, 'model_metrics': {'tab_ddpm_concat': {'avg_loss': 0.01993635216760531, 'avg_g_mag_loss': 0.6526490935406888, 'avg_g_cos_loss': 4.4486220484027385e-08, 'pred_duration': 0.8737220764160156, 'grad_duration': 0.5640714168548584, 
'total_duration': 1.437793493270874, 'pred_std': 0.053181204944849014, 'std_loss': 0.7849577069282532, 'mean_pred_loss': 1.8548063962953165e-05, 'pred_rmse': 0.141196146607399, 'pred_mae': 0.0972040668129921, 'pred_mape': 0.7692168354988098, 'grad_rmse': 0.28230687975883484, 'grad_mae': 0.19419056177139282, 'grad_mape': 0.9970712065696716}}}\n"
]
}
],
"source": [
"import torch\n",
"from ml_utility_loss.loss_learning.estimator.pipeline import train, train_2\n",
"from ml_utility_loss.loss_learning.estimator.process_simple import train_epoch, eval as _eval\n",
"from ml_utility_loss.params import GradientPenaltyMode\n",
"from ml_utility_loss.util import clear_memory\n",
"import time\n",
"#torch.autograd.set_detect_anomaly(True)\n",
"\n",
"clear_memory()\n",
"\n",
"opt = params[\"Optim\"](model.parameters())\n",
"loss = train_2(\n",
" [train_set, val_set, test_set],\n",
" preprocessor=preprocessor,\n",
" whole_model=model,\n",
" optim=opt,\n",
" log_dir=\"logs\",\n",
" checkpoint_dir=\"checkpoints\",\n",
" verbose=True,\n",
" allow_same_prediction=allow_same_prediction,\n",
" wandb=wandb,\n",
" study_name=study_name,\n",
" **params\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "9b514a07",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:42:55.875215Z",
"iopub.status.busy": "2024-02-29T18:42:55.874907Z",
"iopub.status.idle": "2024-02-29T18:42:55.879007Z",
"shell.execute_reply": "2024-02-29T18:42:55.878119Z"
},
"papermill": {
"duration": 0.025079,
"end_time": "2024-02-29T18:42:55.880936",
"exception": false,
"start_time": "2024-02-29T18:42:55.855857",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"model = loss[\"whole_model\"]\n",
"opt = loss[\"optim\"]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "331a49e1",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:42:55.915833Z",
"iopub.status.busy": "2024-02-29T18:42:55.915549Z",
"iopub.status.idle": "2024-02-29T18:42:55.991917Z",
"shell.execute_reply": "2024-02-29T18:42:55.990964Z"
},
"papermill": {
"duration": 0.096139,
"end_time": "2024-02-29T18:42:55.994052",
"exception": false,
"start_time": "2024-02-29T18:42:55.897913",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"import torch\n",
"from copy import deepcopy\n",
"\n",
"torch.save(deepcopy(model.state_dict()), \"model.pt\")\n",
"#torch.save(deepcopy(opt.state_dict()), \"optim.pt\")"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "123b4b17",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:42:56.032185Z",
"iopub.status.busy": "2024-02-29T18:42:56.031896Z",
"iopub.status.idle": "2024-02-29T18:42:56.328976Z",
"shell.execute_reply": "2024-02-29T18:42:56.328057Z"
},
"papermill": {
"duration": 0.319257,
"end_time": "2024-02-29T18:42:56.331162",
"exception": false,
"start_time": "2024-02-29T18:42:56.011905",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"<Axes: >"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 300x300 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"history = loss[\"history\"]\n",
"history.to_csv(\"history.csv\")\n",
"history[[\"avg_loss_train\", \"avg_loss_test\"]].plot()"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "2586ba0a",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:42:56.370376Z",
"iopub.status.busy": "2024-02-29T18:42:56.370066Z",
"iopub.status.idle": "2024-02-29T18:43:43.371513Z",
"shell.execute_reply": "2024-02-29T18:43:43.370520Z"
},
"papermill": {
"duration": 47.023831,
"end_time": "2024-02-29T18:43:43.374040",
"exception": false,
"start_time": "2024-02-29T18:42:56.350209",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"\n",
"from ml_utility_loss.loss_learning.estimator.pipeline import eval\n",
"#eval_loss = loss[\"eval_loss\"]\n",
"\n",
"batch_size = params[\"batch_size_low\"] if \"batch_size_low\" in params else params[\"batch_size\"]\n",
"\n",
"eval_loss = eval(\n",
" test_set, model,\n",
" batch_size=batch_size,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "187137f6",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:43:43.414397Z",
"iopub.status.busy": "2024-02-29T18:43:43.413556Z",
"iopub.status.idle": "2024-02-29T18:43:43.433770Z",
"shell.execute_reply": "2024-02-29T18:43:43.432945Z"
},
"papermill": {
"duration": 0.042663,
"end_time": "2024-02-29T18:43:43.435796",
"exception": false,
"start_time": "2024-02-29T18:43:43.393133",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>avg_g_cos_loss</th>\n",
" <th>avg_g_mag_loss</th>\n",
" <th>avg_loss</th>\n",
" <th>grad_duration</th>\n",
" <th>grad_mae</th>\n",
" <th>grad_mape</th>\n",
" <th>grad_rmse</th>\n",
" <th>mean_pred_loss</th>\n",
" <th>pred_duration</th>\n",
" <th>pred_mae</th>\n",
" <th>pred_mape</th>\n",
" <th>pred_rmse</th>\n",
" <th>pred_std</th>\n",
" <th>std_loss</th>\n",
" <th>total_duration</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>tab_ddpm_concat</th>\n",
" <td>5.952382e-08</td>\n",
" <td>0.609263</td>\n",
" <td>0.019936</td>\n",
" <td>0.559148</td>\n",
" <td>0.194191</td>\n",
" <td>0.997071</td>\n",
" <td>0.282307</td>\n",
" <td>0.000019</td>\n",
" <td>0.876673</td>\n",
" <td>0.097204</td>\n",
" <td>0.769236</td>\n",
" <td>0.141196</td>\n",
" <td>0.053181</td>\n",
" <td>0.784962</td>\n",
" <td>1.435821</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" avg_g_cos_loss avg_g_mag_loss avg_loss grad_duration \\\n",
"tab_ddpm_concat 5.952382e-08 0.609263 0.019936 0.559148 \n",
"\n",
" grad_mae grad_mape grad_rmse mean_pred_loss \\\n",
"tab_ddpm_concat 0.194191 0.997071 0.282307 0.000019 \n",
"\n",
" pred_duration pred_mae pred_mape pred_rmse pred_std \\\n",
"tab_ddpm_concat 0.876673 0.097204 0.769236 0.141196 0.053181 \n",
"\n",
" std_loss total_duration \n",
"tab_ddpm_concat 0.784962 1.435821 "
]
},
"execution_count": 29,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"metrics = pd.DataFrame(eval_loss[\"model_metrics\"]).T\n",
"metrics.to_csv(\"eval.csv\")\n",
"metrics"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "123d305b",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:43:43.475408Z",
"iopub.status.busy": "2024-02-29T18:43:43.474824Z",
"iopub.status.idle": "2024-02-29T18:43:43.909232Z",
"shell.execute_reply": "2024-02-29T18:43:43.908268Z"
},
"papermill": {
"duration": 0.456458,
"end_time": "2024-02-29T18:43:43.911370",
"exception": false,
"start_time": "2024-02-29T18:43:43.454912",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"from ml_utility_loss.util import clear_memory\n",
"clear_memory()"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "a3eecc2a",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:43:43.951744Z",
"iopub.status.busy": "2024-02-29T18:43:43.950961Z",
"iopub.status.idle": "2024-02-29T18:44:32.670044Z",
"shell.execute_reply": "2024-02-29T18:44:32.669005Z"
},
"papermill": {
"duration": 48.741699,
"end_time": "2024-02-29T18:44:32.672417",
"exception": false,
"start_time": "2024-02-29T18:43:43.930718",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Caching in ../../../../insurance/_cache_test/tab_ddpm_concat/all inf False\n"
]
}
],
"source": [
"#\"\"\"\n",
"from ml_utility_loss.loss_learning.estimator.process import pred, pred_2\n",
"from ml_utility_loss.util import stack_samples\n",
"\n",
"#samples = test_set[list(range(len(test_set)))]\n",
"#y = {m: pred(model[m], s) for m, s in samples.items()}\n",
"y = pred_2(model, test_set, batch_size=batch_size)\n",
"#\"\"\""
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "6ab51db8",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:44:32.713356Z",
"iopub.status.busy": "2024-02-29T18:44:32.713056Z",
"iopub.status.idle": "2024-02-29T18:44:32.730203Z",
"shell.execute_reply": "2024-02-29T18:44:32.729517Z"
},
"papermill": {
"duration": 0.038975,
"end_time": "2024-02-29T18:44:32.732107",
"exception": false,
"start_time": "2024-02-29T18:44:32.693132",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": [
"import os\n",
"import pandas as pd\n",
"from ml_utility_loss.util import transpose_dict\n",
"\n",
"os.makedirs(\"pred\", exist_ok=True)\n",
"y2 = transpose_dict(y)\n",
"for k, v in y2.items():\n",
" df = pd.DataFrame(v)\n",
" df.to_csv(f\"pred/{k}.csv\")"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "d81a30f1",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:44:32.769055Z",
"iopub.status.busy": "2024-02-29T18:44:32.768530Z",
"iopub.status.idle": "2024-02-29T18:44:32.773631Z",
"shell.execute_reply": "2024-02-29T18:44:32.772799Z"
},
"papermill": {
"duration": 0.025757,
"end_time": "2024-02-29T18:44:32.775758",
"exception": false,
"start_time": "2024-02-29T18:44:32.750001",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'tab_ddpm_concat': 0.05795487974159697}\n"
]
}
],
"source": [
"print({k: sum(v[\"pred\"])/len(v[\"pred\"]) for k, v in y.items()})"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "3b3ff322",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:44:32.813941Z",
"iopub.status.busy": "2024-02-29T18:44:32.813678Z",
"iopub.status.idle": "2024-02-29T18:44:33.119654Z",
"shell.execute_reply": "2024-02-29T18:44:33.118662Z"
},
"papermill": {
"duration": 0.327687,
"end_time": "2024-02-29T18:44:33.121766",
"exception": false,
"start_time": "2024-02-29T18:44:32.794079",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 300x300 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from ml_utility_loss.loss_learning.visualization import plot_pred_density_2\n",
"\n",
"_ = plot_pred_density_2(y)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "e79e4b0f",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:44:33.160005Z",
"iopub.status.busy": "2024-02-29T18:44:33.159736Z",
"iopub.status.idle": "2024-02-29T18:44:33.460975Z",
"shell.execute_reply": "2024-02-29T18:44:33.460154Z"
},
"papermill": {
"duration": 0.322423,
"end_time": "2024-02-29T18:44:33.462943",
"exception": false,
"start_time": "2024-02-29T18:44:33.140520",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 300x300 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from ml_utility_loss.loss_learning.visualization import plot_density_3\n",
"\n",
"_ = plot_density_3(y2[\"pred\"], next(iter(y2[\"y\"].values())))"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "745adde1",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:44:33.503100Z",
"iopub.status.busy": "2024-02-29T18:44:33.502818Z",
"iopub.status.idle": "2024-02-29T18:44:33.729512Z",
"shell.execute_reply": "2024-02-29T18:44:33.728656Z"
},
"papermill": {
"duration": 0.248914,
"end_time": "2024-02-29T18:44:33.731368",
"exception": false,
"start_time": "2024-02-29T18:44:33.482454",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 300x300 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from ml_utility_loss.loss_learning.visualization import plot_box_3\n",
"\n",
"_ = plot_box_3(y2[\"pred\"], next(iter(y2[\"y\"].values())))"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "eabe1bab",
"metadata": {
"execution": {
"iopub.execute_input": "2024-02-29T18:44:33.772296Z",
"iopub.status.busy": "2024-02-29T18:44:33.772018Z",
"iopub.status.idle": "2024-02-29T18:44:34.042778Z",
"shell.execute_reply": "2024-02-29T18:44:34.041888Z"
},
"papermill": {
"duration": 0.293553,
"end_time": "2024-02-29T18:44:34.044775",
"exception": false,
"start_time": "2024-02-29T18:44:33.751222",
"status": "completed"
},
"tags": []
},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 300x300 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"#\"\"\"\n",
"from ml_utility_loss.loss_learning.visualization import plot_grad, plot_grad_2, plot_grad_3\n",
"import matplotlib.pyplot as plt\n",
"\n",
"#plot_grad_2(y, model.models)\n",
"for m in model.models:\n",
" ym = y[m]\n",
" fig, ax = plt.subplots()\n",
" plot_grad_3(ym[\"error\"], ym[\"grad\"], name=f\"{m}_grad\", fig=fig, ax=ax)\n",
"#\"\"\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "54c0e9f3",
"metadata": {
"papermill": {
"duration": 0.019946,
"end_time": "2024-02-29T18:44:34.084886",
"exception": false,
"start_time": "2024-02-29T18:44:34.064940",
"status": "completed"
},
"tags": []
},
"outputs": [],
"source": []
}
],
"metadata": {
"accelerator": "GPU",
"celltoolbar": "Tags",
"colab": {
"authorship_tag": "ABX9TyOOVfelovKP9fLGU7SvvRie",
"gpuType": "T4",
"mount_file_id": "17POSGAvge8y9DW9WGs2jLkibaRjToayg",
"provenance": []
},
"kaggle": {
"accelerator": "gpu",
"dataSources": [],
"dockerImageVersionId": 30648,
"isGpuEnabled": true,
"isInternetEnabled": true,
"language": "python",
"sourceType": "notebook"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.0"
},
"papermill": {
"default_parameters": {},
"duration": 1067.648958,
"end_time": "2024-02-29T18:44:36.826930",
"environment_variables": {},
"exception": null,
"input_path": "eval/insurance/tab_ddpm_concat/4/mlu-eval.ipynb",
"output_path": "eval/insurance/tab_ddpm_concat/4/mlu-eval.ipynb",
"parameters": {
"allow_same_prediction": true,
"dataset": "insurance",
"dataset_name": "insurance",
"debug": false,
"folder": "eval",
"gp": false,
"gp_multiply": false,
"param_index": 2,
"path": "eval/insurance/tab_ddpm_concat/4",
"path_prefix": "../../../../",
"random_seed": 4,
"single_model": "tab_ddpm_concat"
},
"start_time": "2024-02-29T18:26:49.177972",
"version": "2.5.0"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": false,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 5
}
|