rizqinur2010 commited on
Commit
f7a584f
·
1 Parent(s): e09e808
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. contraceptive/lct_gan/eval.csv +2 -0
  2. contraceptive/lct_gan/history.csv +11 -0
  3. contraceptive/lct_gan/mlu-eval.ipynb +0 -0
  4. contraceptive/lct_gan/model.pt +3 -0
  5. contraceptive/lct_gan/params.json +1 -0
  6. contraceptive/realtabformer/eval.csv +2 -0
  7. contraceptive/realtabformer/history.csv +10 -0
  8. contraceptive/realtabformer/mlu-eval.ipynb +0 -0
  9. contraceptive/realtabformer/model.pt +3 -0
  10. contraceptive/realtabformer/params.json +1 -0
  11. contraceptive/tab_ddpm_concat/eval.csv +2 -0
  12. contraceptive/tab_ddpm_concat/history.csv +11 -0
  13. contraceptive/tab_ddpm_concat/mlu-eval.ipynb +0 -0
  14. contraceptive/tab_ddpm_concat/model.pt +3 -0
  15. contraceptive/tab_ddpm_concat/params.json +1 -0
  16. contraceptive/tvae/eval.csv +2 -0
  17. contraceptive/tvae/history.csv +11 -0
  18. contraceptive/tvae/mlu-eval.ipynb +0 -0
  19. contraceptive/tvae/model.pt +3 -0
  20. contraceptive/tvae/params.json +1 -0
  21. insurance/lct_gan/eval.csv +2 -0
  22. insurance/lct_gan/history.csv +18 -0
  23. insurance/lct_gan/mlu-eval.ipynb +0 -0
  24. insurance/lct_gan/model.pt +3 -0
  25. insurance/lct_gan/params.json +1 -0
  26. insurance/realtabformer/eval.csv +2 -0
  27. insurance/realtabformer/history.csv +17 -0
  28. insurance/realtabformer/mlu-eval.ipynb +0 -0
  29. insurance/realtabformer/model.pt +3 -0
  30. insurance/realtabformer/params.json +1 -0
  31. insurance/tab_ddpm_concat/eval.csv +2 -0
  32. insurance/tab_ddpm_concat/history.csv +17 -0
  33. insurance/tab_ddpm_concat/mlu-eval.ipynb +0 -0
  34. insurance/tab_ddpm_concat/model.pt +3 -0
  35. insurance/tab_ddpm_concat/params.json +1 -0
  36. insurance/tvae/eval.csv +2 -0
  37. insurance/tvae/history.csv +13 -0
  38. insurance/tvae/mlu-eval.ipynb +0 -0
  39. insurance/tvae/model.pt +3 -0
  40. insurance/tvae/params.json +1 -0
  41. treatment/lct_gan/eval.csv +2 -0
  42. treatment/lct_gan/history.csv +8 -0
  43. treatment/lct_gan/mlu-eval.ipynb +0 -0
  44. treatment/lct_gan/model.pt +3 -0
  45. treatment/lct_gan/params.json +1 -0
  46. treatment/realtabformer/eval.csv +2 -0
  47. treatment/realtabformer/history.csv +5 -0
  48. treatment/realtabformer/mlu-eval.ipynb +0 -0
  49. treatment/realtabformer/model.pt +3 -0
  50. treatment/realtabformer/params.json +1 -0
contraceptive/lct_gan/eval.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
2
+ lct_gan,0.005347351378184699,0.08097088616235722,0.002836061054452633,12.46730089187622,0.03379097953438759,0.8471567034721375,0.1400071233510971,9.070246051123831e-06,4.109842538833618,0.04161286726593971,0.12254983186721802,0.05325467884540558,0.1082986444234848,0.0010494085727259517,16.57714343070984
contraceptive/lct_gan/history.csv ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
2
+ 0,0.017401078423782666,1.0339260609464513,0.0010806548900643828,0.16402181821875275,0.0,0.0,0.0,0.0,0.01764664788174236,900,225,261.55655670166016,1.1624735853407118,0.29061839633517794,0.1123641776252124,0.016450096456747915,1.6555652437911634,0.0006544369515005302,0.0,0.0,0.0,0.0,0.0,0.016450096456747915,450,113,90.59106540679932,0.8016908443079586,0.20131347868177626,0.10719071906595697
3
+ 1,0.007544516106006793,0.805866371501884,0.00011455891246516556,0.05731394776852944,0.0,0.0,0.0,0.0,0.007670205862442446,900,225,262.5022921562195,1.1666768540276422,0.29166921350691055,0.08921317261954148,0.004202535958288031,0.7761073129848487,4.496906836265167e-05,0.0,0.0,0.0,0.0,0.0,0.004202535958288031,450,113,90.03428149223328,0.796763553028613,0.20007618109385172,0.058546484567521685
4
+ 2,0.006512888989463035,0.7518540992810281,0.0001415838299930615,0.049427424324288344,0.0,0.0,0.0,0.0,0.0066497801841857536,900,225,262.3658037185669,1.1660702387491861,0.29151755968729653,0.09130782820491327,0.0055336219292237525,1.1570013584530916,6.174433674384281e-05,0.0,0.0,0.0,0.0,0.0,0.0055336219292237525,450,113,89.11512207984924,0.788629398936719,0.1980336046218872,0.04651118999973467
5
+ 3,0.006024511704712899,0.4761890591268285,0.0002143375033454278,0.06465743926004507,0.0,0.0,0.0,0.0,0.0061261716642830935,900,225,260.4550771713257,1.1575781207614475,0.2893945301903619,0.09767564491679272,0.004724722134932462,0.7707038955945308,5.621562246012167e-05,0.0,0.0,0.0,0.0,0.0,0.004724722134932462,450,113,87.9228572845459,0.7780783830490787,0.19538412729899088,0.056170098961586444
6
+ 4,0.005286015553380518,0.3363608939476823,7.321617665753689e-05,0.045423433695816334,0.0,0.0,0.0,0.0,0.005374089340039063,900,225,260.1741499900818,1.1563295555114745,0.28908238887786863,0.09759037269486322,0.007981064757849607,1.2652324522273108,9.855836417115466e-05,0.0,0.0,0.0,0.0,0.0,0.007981064757849607,450,113,88.88254165649414,0.7865711651017181,0.19751675923665366,0.04572650106969924
7
+ 5,0.00576256091059703,0.5418789782114618,7.038583156805957e-05,0.02929313911823556,0.0,0.0,0.0,0.0,0.005917770875255681,900,225,262.8189432621002,1.168084192276001,0.29202104806900026,0.09744484349257417,0.0038669909037222774,1.1829836725373764,2.4478697872928952e-05,0.0,0.0,0.0,0.0,0.0,0.0038669909037222774,450,113,90.46491193771362,0.8005744419266693,0.20103313763936362,0.047730479976656824
8
+ 6,0.0038541355246626253,0.3181302580921152,2.260195795584597e-05,0.02967498921504658,0.0,0.0,0.0,0.0,0.003917783604055229,900,225,261.9046974182129,1.1640208774142795,0.2910052193535699,0.10003261071940263,0.0038603629919493365,1.908042863952322,1.606306340802302e-05,0.0,0.0,0.0,0.0,0.0,0.0038603629919493365,450,113,90.71723699569702,0.8028074070415666,0.20159385999043783,0.05183640702017706
9
+ 7,0.003702066924338902,0.2994743646377197,2.3150445486223394e-05,0.018106737517001523,0.0,0.0,0.0,0.0,0.0037678834933709974,900,225,260.3814172744751,1.1572507434421115,0.2893126858605279,0.10073738541454076,0.0028049794745553906,1.4249755061280251,1.3086192913582816e-05,0.0,0.0,0.0,0.0,0.0,0.0028049794745553906,450,113,90.54383492469788,0.8012728754398042,0.20120852205488418,0.054640445479117665
10
+ 8,0.0033579401259905555,0.3882562007228762,1.8571260889600453e-05,0.031065790198707772,0.0,0.0,0.0,0.0,0.0034128941754655293,900,225,261.0693860054016,1.1603083822462295,0.29007709556155736,0.10073215851146314,0.004292678962616871,2.506649698973513,4.666494623541402e-05,0.0,0.0,0.0,0.0,0.0,0.004292678962616871,450,113,89.04813385009766,0.788036582744227,0.1978847418891059,0.05573896699857
11
+ 9,0.0032293959062533557,0.4635296431291696,1.8672718882848967e-05,0.023970585422034167,0.0,0.0,0.0,0.0,0.003291322695731651,900,225,260.9506549835205,1.1597806888156468,0.2899451722039117,0.10359380051907566,0.002744582254672423,2.2099234429464625,2.5944116722460305e-05,0.0,0.0,0.0,0.0,0.0,0.002744582254672423,450,113,89.43556237220764,0.7914651537363508,0.19874569416046142,0.05300469076635926
contraceptive/lct_gan/mlu-eval.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
contraceptive/lct_gan/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d0bc27b52d8f3e6cb7d7a93df8356db260bb7d31f833b72a29a856dc5da2b511
3
+ size 47605515
contraceptive/lct_gan/params.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true, "forgive_over": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "none", "patience": 10, "inds_init_mode": "fixnorm", "grad_clip": 0.73, "gradient_penalty_mode": "ALL", "synth_data": 2, "bias_lr_mul": 1.0, "bias_weight_decay": 0.05, "head_activation": "softsign", "loss_balancer_beta": 0.67, "loss_balancer_r": 0.943, "tf_activation": "tanh", "dataset_size": 2048, "batch_size": 4, "epochs": 100, "lr_mul": 0.09, "n_warmup_steps": 100, "Optim": "amsgradw", "fixed_role_model": "lct_gan", "mse_mag": true, "mse_mag_target": 0.65, "mse_mag_multiply": true, "d_model": 256, "attn_activation": "prelu", "tf_d_inner": 512, "tf_n_layers_enc": 3, "tf_n_head": 32, "tf_activation_final": "leakyhardtanh", "tf_num_inds": 128, "ada_d_hid": 1024, "ada_n_layers": 9, "ada_activation": "softsign", "ada_activation_final": "leakyhardsigmoid", "head_d_hid": 256, "head_n_layers": 9, "head_n_head": 32, "head_activation_final": "leakyhardsigmoid", "models": ["lct_gan"], "max_seconds": 3600}
contraceptive/realtabformer/eval.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
2
+ realtabformer,0.010087525346489634,0.03493708366106431,0.002533305302744598,8.234226942062378,0.25083908438682556,6.304386138916016,0.44913777709007263,1.1285437722108327e-05,8.565300703048706,0.03761804848909378,0.12008487433195114,0.050331953912973404,0.09621766954660416,0.016873924061655998,16.799527645111084
contraceptive/realtabformer/history.csv ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ ,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
2
+ 0,0.014773491756835332,0.5065707220365463,0.0010022340709901162,2.0442994761798117,0.0,0.0,0.0,0.0,0.015191531291671305,900,225,274.4165141582489,1.2196289518144396,0.3049072379536099,0.11575380941232045,0.007571417718944657,1.554786379107431,0.00017871774661915499,0.0,0.0,0.0,0.0,0.0,0.007571417718944657,450,113,91.69575762748718,0.8114668816591786,0.20376835028330484,0.061720233775296174
3
+ 1,0.007462754625658918,0.9658913450170937,0.00014313012270514552,0.8396876566774315,0.0,0.0,0.0,0.0,0.007641434397470827,900,225,273.0821771621704,1.213698565165202,0.3034246412913005,0.09523055639531877,0.006775144316876928,2.6273896950355087,9.84499310450578e-05,0.0,0.0,0.0,0.0,0.0,0.006775144316876928,450,113,91.63136982917786,0.8108970781343173,0.2036252662870619,0.046712872879249995
4
+ 2,0.004955892372494822,0.7045143008656533,7.975020945901844e-05,0.6091641951931848,0.0,0.0,0.0,0.0,0.005083349947817624,900,225,273.1290547847748,1.2139069101545545,0.30347672753863864,0.09947649084807685,0.0036339714314736838,3.456087975117349,5.305239131944066e-05,0.0,0.0,0.0,0.0,0.0,0.0036339714314736838,450,113,91.59913516044617,0.8106118155791696,0.20355363368988036,0.05323507617829384
5
+ 3,0.0038874727278339883,0.5685213102487542,3.110381803025462e-05,0.44001219677428405,0.0,0.0,0.0,0.0,0.003981491989947648,900,225,273.2383725643158,1.2143927669525147,0.3035981917381287,0.09838743486338192,0.00476446549566592,8.537542689921613,7.457373639723973e-05,0.0,0.0,0.0,0.0,0.0,0.00476446549566592,450,113,91.5707778930664,0.8103608663103222,0.20349061754014758,0.04942838404795353
6
+ 4,0.003344398294769538,0.42805399294468366,2.9547856750660956e-05,0.4788584218091435,0.0,0.0,0.0,0.0,0.0034413283928168108,900,225,273.74941539764404,1.2166640684339736,0.3041660171084934,0.10119029613832632,0.003117711971394278,3.474242146586163,4.511719159815966e-05,0.0,0.0,0.0,0.0,0.0,0.003117711971394278,450,113,91.81963801383972,0.8125631682640684,0.20404364003075492,0.06101849829712141
7
+ 5,0.003131055658159312,0.4019659260851485,2.7654838360707502e-05,0.4308759291966756,0.0,0.0,0.0,0.0,0.0032185360684201846,900,225,275.91382932662964,1.2262836858961317,0.3065709214740329,0.10080154451231162,0.0037931067653052095,2.7604550893965283,4.1587406659311395e-05,0.0,0.0,0.0,0.0,0.0,0.0037931067653052095,450,113,91.9511308670044,0.8137268218318973,0.20433584637112087,0.055654415548614236
8
+ 6,0.0029877640384883206,0.366934466270567,1.5397459504559052e-05,0.43426220549477473,0.0,0.0,0.0,0.0,0.003075523809238803,900,225,275.31064915657043,1.2236028851403131,0.3059007212850783,0.10231422000668115,0.003043726567929197,3.6605247210606398,3.023255854904125e-05,0.0,0.0,0.0,0.0,0.0,0.003043726567929197,450,113,93.64256858825684,0.8286952972412109,0.20809459686279297,0.04299795463994409
9
+ 7,0.0027128174370348763,0.6366902967303084,1.8083823394538535e-05,0.3685302308367358,0.0,0.0,0.0,0.0,0.0027878411782633825,900,225,278.0753390789032,1.2358903959062364,0.3089725989765591,0.0995295613238381,0.0033965686908535037,3.677261768688456,3.01374123970701e-05,0.0,0.0,0.0,0.0,0.0,0.0033965686908535037,450,113,94.00029134750366,0.8318609853761386,0.20888953632778592,0.048798565403043205
10
+ 8,0.00263197087019863,0.38799835741244154,1.5526932464568304e-05,0.40329953748318886,0.0,0.0,0.0,0.0,0.00271295235922379,900,225,275.622722864151,1.2249898793962266,0.30624746984905665,0.10333809573306806,0.0033174797300145856,3.981254602707383,1.7999465464147374e-05,0.0,0.0,0.0,0.0,0.0,0.0033174797300145856,450,113,91.82101058959961,0.8125753149522089,0.20404669019911023,0.048110959254796574
contraceptive/realtabformer/mlu-eval.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
contraceptive/realtabformer/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:173cd69ed666cb094528dcde463eb2c0dc1b60c91bcf8c38b1b592d99dbc9a98
3
+ size 50388737
contraceptive/realtabformer/params.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true, "forgive_over": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "none", "patience": 10, "inds_init_mode": "fixnorm", "grad_clip": 0.73, "gradient_penalty_mode": "ALL", "synth_data": 2, "bias_lr_mul": 1.0, "bias_weight_decay": 0.05, "head_activation": "softsign", "loss_balancer_beta": 0.67, "loss_balancer_r": 0.943, "tf_activation": "tanh", "dataset_size": 2048, "batch_size": 4, "epochs": 100, "lr_mul": 0.09, "n_warmup_steps": 100, "Optim": "amsgradw", "fixed_role_model": "realtabformer", "mse_mag": true, "mse_mag_target": 0.65, "mse_mag_multiply": true, "d_model": 256, "attn_activation": "prelu", "tf_d_inner": 512, "tf_n_layers_enc": 3, "tf_n_head": 32, "tf_activation_final": "leakyhardtanh", "tf_num_inds": 128, "ada_d_hid": 1024, "ada_n_layers": 9, "ada_activation": "softsign", "ada_activation_final": "leakyhardsigmoid", "head_d_hid": 256, "head_n_layers": 9, "head_n_head": 32, "head_activation_final": "leakyhardsigmoid", "models": ["realtabformer"], "max_seconds": 3600}
contraceptive/tab_ddpm_concat/eval.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
2
+ tab_ddpm_concat,0.038817155207590465,0.05224676465664414,0.0031409682811326567,12.584968328475952,0.04438546299934387,1.0239506959915161,0.12085322290658951,1.278079525945941e-05,3.8802053928375244,0.04239872843027115,0.13749264180660248,0.05604434013366699,0.08971857279539108,0.03517554700374603,16.465173721313477
contraceptive/tab_ddpm_concat/history.csv ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
2
+ 0,0.029891989165917038,1.363000964352638,0.0028323000936053095,0.03434117138773824,0.0,0.0,0.0,0.0,0.03190343896651433,900,225,264.29759979248047,1.174655999077691,0.29366399976942276,0.12323957710630364,0.018350417958055105,0.8690771571510773,0.0006744714792743404,0.0,0.0,0.0,0.0,0.0,0.018350417958055105,450,113,91.61005926132202,0.810708489038248,0.2035779094696045,0.05547503994332742
3
+ 1,0.026780042523621685,3.719150285677334,0.0013821134704785879,0.06253340480448161,0.0,0.0,0.0,0.0,0.02846383413299918,900,225,265.14594316482544,1.1784264140658909,0.2946066035164727,0.08547261928673834,0.020631530765030118,4.440361282132123,0.0007865790576680511,0.0,0.0,0.0,0.0,0.0,0.020631530765030118,450,113,92.29275822639465,0.8167500727999527,0.205095018280877,0.06341225340801696
4
+ 2,0.009797142112058484,1.345002487987807,0.00018001115952997852,0.014916595776513632,0.0,0.0,0.0,0.0,0.01049884227078615,900,225,267.053644657135,1.1869050873650444,0.2967262718412611,0.07698791329231527,0.0058155485348672506,2.1949546487962954,7.145489165522021e-05,0.0,0.0,0.0,0.0,0.0,0.0058155485348672506,450,113,92.60731506347656,0.8195337616236864,0.20579403347439237,0.03643730922346621
5
+ 3,0.00591874976532482,0.5914836981606755,8.3291523560353e-05,0.015630586181456844,0.0,0.0,0.0,0.0,0.013930044301410413,900,225,265.62826585769653,1.1805700704786513,0.2951425176196628,0.0992426086589694,0.005116828617950281,1.3365860288515323,4.988687947685675e-05,0.0,0.0,0.0,0.0,0.0,0.005116828617950281,450,113,92.46403670310974,0.8182658115319447,0.20547563711802164,0.0723611570861751
6
+ 4,0.006043252464086335,0.6006597002678468,7.193107055212267e-05,0.0193890755618405,0.0,0.0,0.0,0.0,0.006235779779187093,900,225,265.7200925350189,1.1809781890445286,0.29524454726113214,0.09385471865741743,0.0037355842368884218,2.1964674235988046,1.4399180043685393e-05,0.0,0.0,0.0,0.0,0.0,0.0037355842368884218,450,113,92.21309638023376,0.8160451007100333,0.20491799195607505,0.03963511501257596
7
+ 5,0.005593134965747595,0.5571283631080762,5.060047336453723e-05,0.005880406767505014,0.0,0.0,0.0,0.0,0.006675898930989205,900,225,265.53266191482544,1.1801451640658909,0.2950362910164727,0.09367100126213497,0.005213722620262868,1.4734984655637393,3.6257587302249085e-05,0.0,0.0,0.0,0.0,0.0,0.005213722620262868,450,113,92.31340265274048,0.8169327668384113,0.20514089478386774,0.04706298913064916
8
+ 6,0.004435698193054931,0.5322100268975697,2.625556939914274e-05,0.017856327523348026,0.0,0.0,0.0,0.0,0.004762566664511622,900,225,265.40029549598694,1.1795568688710532,0.2948892172177633,0.09591990354160468,0.0032878017918361972,2.197284267162057,1.6023078919465373e-05,0.0,0.0,0.0,0.0,0.0,0.0032878017918361972,450,113,92.23695397377014,0.8162562298563729,0.20497100883060032,0.04823664502257201
9
+ 7,0.004083349550039404,0.2640936150784081,3.919473733837115e-05,0.007234169057984319,0.0,0.0,0.0,0.0,0.004312839023510201,900,225,265.5356845855713,1.1801585981580947,0.29503964953952366,0.10409059958325492,0.0049268581152945344,2.1081838779046094,2.783397738109588e-05,0.0,0.0,0.0,0.0,0.0,0.0049268581152945344,450,113,92.30034828186035,0.8168172414323925,0.20511188507080078,0.03578016849902285
10
+ 8,0.004650288004955251,0.3868287038255513,4.264861073708936e-05,0.014303723979844815,0.0,0.0,0.0,0.0,0.004827073722052672,900,225,265.6125736236572,1.1805003272162544,0.2951250818040636,0.09734993250005775,0.00729073759013166,1.2893098778926795,7.92939556772391e-05,0.0,0.0,0.0,0.0,0.0,0.00729073759013166,450,113,92.1851875782013,0.8157981201610734,0.20485597239600287,0.04885830329442644
11
+ 9,0.005149225248670619,0.3466197312315247,4.727260695603993e-05,0.00807437449758355,0.0,0.0,0.0,0.0,0.005820203127675793,900,225,262.5310835838318,1.1668048159281412,0.2917012039820353,0.10039696198784642,0.0032142305604389142,1.9432415218951797,1.8801717722896813e-05,0.0,0.0,0.0,0.0,0.0,0.0032142305604389142,450,113,89.91775798797607,0.7957323715750095,0.19981723997328016,0.05169614580196155
contraceptive/tab_ddpm_concat/mlu-eval.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
contraceptive/tab_ddpm_concat/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b3ac873b73d5d19d404bc1940f245ef136580c00b5a6df3a257860ada70fa96
3
+ size 47482955
contraceptive/tab_ddpm_concat/params.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true, "forgive_over": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "none", "patience": 10, "inds_init_mode": "fixnorm", "grad_clip": 0.73, "gradient_penalty_mode": "ALL", "synth_data": 2, "bias_lr_mul": 1.0, "bias_weight_decay": 0.05, "head_activation": "softsign", "loss_balancer_beta": 0.67, "loss_balancer_r": 0.943, "tf_activation": "tanh", "dataset_size": 2048, "batch_size": 4, "epochs": 100, "lr_mul": 0.09, "n_warmup_steps": 100, "Optim": "amsgradw", "fixed_role_model": "tab_ddpm_concat", "mse_mag": true, "mse_mag_target": 0.65, "mse_mag_multiply": true, "d_model": 256, "attn_activation": "prelu", "tf_d_inner": 512, "tf_n_layers_enc": 3, "tf_n_head": 32, "tf_activation_final": "leakyhardtanh", "tf_num_inds": 128, "ada_d_hid": 1024, "ada_n_layers": 9, "ada_activation": "softsign", "ada_activation_final": "leakyhardsigmoid", "head_d_hid": 256, "head_n_layers": 9, "head_n_head": 32, "head_activation_final": "leakyhardsigmoid", "models": ["tab_ddpm_concat"], "max_seconds": 3600}
contraceptive/tvae/eval.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
2
+ tvae,0.029584258647425615,0.02876510211709655,0.0026995846250792966,12.484404563903809,0.019254347309470177,0.5103664398193359,0.03382880240678787,4.6931290853535756e-05,4.136843204498291,0.03707587346434593,0.11939960718154907,0.05195752531290054,0.09320646524429321,0.024370625615119934,16.6212477684021
contraceptive/tvae/history.csv ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
2
+ 0,0.02038899842772581,0.7097008107104881,0.0013612247566806238,0.2463153438932366,0.0,0.0,0.0,0.0,0.020629282327491737,900,225,260.0577940940857,1.1558124181959364,0.2889531045489841,0.1322056857123971,0.013955928831257755,0.32679080615364564,0.0004168977530578862,0.0,0.0,0.0,0.0,0.0,0.013955928831257755,450,113,87.91127800941467,0.7779759115877405,0.19535839557647705,0.10100915132964079
3
+ 1,0.0060118012685173505,0.5180407320536465,8.332759030961423e-05,0.06306991064869281,0.0,0.0,0.0,0.0,0.00608574254250723,900,225,260.53746366500854,1.1579442829555935,0.28948607073889837,0.09532983317143387,0.004431044542773937,4.37402327783609,5.788037800186684e-05,0.0,0.0,0.0,0.0,0.0,0.004431044542773937,450,113,90.43914699554443,0.8003464335888888,0.20097588221232096,0.04492716361002057
4
+ 2,0.003981401668424951,0.5578023322469762,3.4067052916835076e-05,0.02774549509638746,0.0,0.0,0.0,0.0,0.004037349244463258,900,225,259.8459405899048,1.1548708470662434,0.28871771176656086,0.09886109303269121,0.005920900385180074,3.167294240776193,6.510238445815651e-05,0.0,0.0,0.0,0.0,0.0,0.005920900385180074,450,113,88.16130113601685,0.7801885056284676,0.19591400252448188,0.04382448092585149
5
+ 3,0.003003391056942443,0.5015569428249138,1.917405674469519e-05,0.019288771962617628,0.0,0.0,0.0,0.0,0.003044349568921866,900,225,258.8182637691498,1.1503033945295547,0.28757584863238866,0.09951512091689639,0.0036749178177625354,5.0160115570675385,3.591052332840726e-05,0.0,0.0,0.0,0.0,0.0,0.0036749178177625354,450,113,88.94795989990234,0.7871500876097552,0.1976621331108941,0.050722146361439895
6
+ 4,0.0026951473932907296,0.46114019461917133,1.2756809318339692e-05,0.018491813861118214,0.0,0.0,0.0,0.0,0.0027525624157472826,900,225,260.7371289730072,1.1588316843244764,0.2897079210811191,0.09791627071694367,0.0032899318864413846,2.2600422874186443,1.737884139242202e-05,0.0,0.0,0.0,0.0,0.0,0.0032899318864413846,450,113,89.44571375846863,0.7915549890129967,0.19876825279659696,0.05520178424967123
7
+ 5,0.002675653763451717,0.38926185907018096,1.1465693117683688e-05,0.018026919938856734,0.0,0.0,0.0,0.0,0.002713999592718513,900,225,261.18925762176514,1.1608411449856229,0.2902102862464057,0.10430583260332545,0.00328355419371898,3.905635658147522,2.2026288097660226e-05,0.0,0.0,0.0,0.0,0.0,0.00328355419371898,450,113,89.86624097824097,0.7952764688339908,0.19970275772942436,0.047640036822469756
8
+ 6,0.002428213983172706,0.4563769066303573,1.070831390430716e-05,0.01809621533375725,0.0,0.0,0.0,0.0,0.0024591475264686678,900,225,262.2440469264984,1.1655290974511041,0.29138227436277603,0.10204424848676556,0.003275549128625749,3.9855575082005594,1.929245272713863e-05,0.0,0.0,0.0,0.0,0.0,0.003275549128625749,450,113,90.92798852920532,0.8046724648602241,0.2020621967315674,0.04646150635676953
9
+ 7,0.0023822973380447365,0.44478256372083136,9.005847006119572e-06,0.01762347323496619,0.0,0.0,0.0,0.0,0.002411980113861824,900,225,263.1176962852478,1.1694119834899903,0.2923529958724976,0.10207737949159411,0.0031462487001489435,3.1703384120910294,2.179660826174649e-05,0.0,0.0,0.0,0.0,0.0,0.0031462487001489435,450,113,88.3796284198761,0.7821206054856292,0.19639917426639134,0.047558308290564906
10
+ 8,0.0021461909939373275,0.3496849273867641,8.150395166715502e-06,0.016793633546072266,0.0,0.0,0.0,0.0,0.0021742351456487084,900,225,264.955197095871,1.1775786537594266,0.29439466343985665,0.10324391664730179,0.00289536593284639,3.853903681079736,2.701310116102227e-05,0.0,0.0,0.0,0.0,0.0,0.00289536593284639,450,113,92.01904273033142,0.8143278117728444,0.2044867616229587,0.05265144564250517
11
+ 9,0.002258309722690481,0.5195468653853743,7.549715697529972e-06,0.019823850064721128,0.0,0.0,0.0,0.0,0.0022869745167554355,900,225,265.574116230011,1.1803294054667155,0.2950823513666789,0.10302899842564431,0.0032758567545291347,1.3863763298404954,4.0907979645609235e-05,0.0,0.0,0.0,0.0,0.0,0.0032758567545291347,450,113,89.21779704093933,0.7895380269109675,0.1982617712020874,0.0622333479762918
contraceptive/tvae/mlu-eval.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
contraceptive/tvae/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0666886122f68ede68dcce3fa85c996284dd7da4c4e0e081f59def727cc47185
3
+ size 47629899
contraceptive/tvae/params.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true, "forgive_over": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "none", "patience": 10, "inds_init_mode": "fixnorm", "grad_clip": 0.73, "gradient_penalty_mode": "ALL", "synth_data": 2, "bias_lr_mul": 1.0, "bias_weight_decay": 0.05, "head_activation": "softsign", "loss_balancer_beta": 0.67, "loss_balancer_r": 0.943, "tf_activation": "tanh", "dataset_size": 2048, "batch_size": 4, "epochs": 100, "lr_mul": 0.09, "n_warmup_steps": 100, "Optim": "amsgradw", "fixed_role_model": "tvae", "mse_mag": true, "mse_mag_target": 0.65, "mse_mag_multiply": true, "d_model": 256, "attn_activation": "prelu", "tf_d_inner": 512, "tf_n_layers_enc": 3, "tf_n_head": 32, "tf_activation_final": "leakyhardtanh", "tf_num_inds": 128, "ada_d_hid": 1024, "ada_n_layers": 9, "ada_activation": "softsign", "ada_activation_final": "leakyhardsigmoid", "head_d_hid": 256, "head_n_layers": 9, "head_n_head": 32, "head_activation_final": "leakyhardsigmoid", "models": ["tvae"], "max_seconds": 3600}
insurance/lct_gan/eval.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
2
+ lct_gan,0.0011455664967208923,0.01398279099403832,0.0009290176306413265,6.546527147293091,0.018045689910650253,0.6029329299926758,0.056619517505168915,9.035532457346562e-06,2.36279296875,0.018268784508109093,0.8310969471931458,0.03047979064285755,0.14884799718856812,2.9791326596750878e-06,8.90932011604309
insurance/lct_gan/history.csv ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
2
+ 0,0.03242904957539092,0.8712406825969596,0.006855423091376893,0.38299333698219723,0.0,0.0,0.0,0.0,0.03276979403259853,900,113,157.391695022583,1.3928468586069294,0.17487966113620335,0.13700437690831918,0.006329906290500528,0.46090240039369457,5.679718335689661e-05,0.0,0.0,0.0,0.0,0.0,0.006329906290500528,450,57,52.734307527542114,0.9251632899568792,0.1171873500612047,0.06423525986049258
3
+ 1,0.007500169986807224,1.1092925603606028,0.00019880676637914756,0.05190713240040673,0.0,0.0,0.0,0.0,0.007618296743732773,900,113,157.97713208198547,1.3980277175396945,0.17553014675776163,0.0854555291609954,0.005137449055619072,0.6586513487276945,2.703822330727585e-05,0.0,0.0,0.0,0.0,0.0,0.005137449055619072,450,57,53.497430086135864,0.9385514050199274,0.1188831779691908,0.05167298022200141
4
+ 2,0.004720190377233343,0.5550237497275922,4.577079138076836e-05,0.02987108025078972,0.0,0.0,0.0,0.0,0.004798027821105077,900,113,156.7364845275879,1.3870485356423707,0.17415164947509765,0.09026794963046512,0.004182245211753374,0.24288320739538394,0.00014376330885359007,0.0,0.0,0.0,0.0,0.0,0.004182245211753374,450,57,52.81124806404114,0.9265131239305463,0.11735832903120252,0.07229022658838515
5
+ 3,0.004049827164530547,0.43805476618803796,5.451335522204717e-05,0.03185774852211277,0.0,0.0,0.0,0.0,0.004109540785429999,900,113,155.9995777606964,1.3805272368203223,0.17333286417855157,0.08781432131288854,0.005127925792023436,0.8044920190759293,3.6524204451706925e-05,0.0,0.0,0.0,0.0,0.0,0.005127925792023436,450,57,52.598074197769165,0.92277323153981,0.11688460932837592,0.04705578919393909
6
+ 4,0.004900188291147869,0.6228896175905209,0.00012289669273024944,0.031160058855182596,0.0,0.0,0.0,0.0,0.005004006718000811,900,113,156.05951523780823,1.3810576569717543,0.17339946137534248,0.08326448575980895,0.00360431135011216,0.1391123805354558,0.00017833255974409213,0.0,0.0,0.0,0.0,0.0,0.00360431135011216,450,57,52.58910870552063,0.9226159422021163,0.11686468601226807,0.07722438271402528
7
+ 5,0.00338898796432962,0.3835711898057332,5.93510390443841e-05,0.04150173789097203,0.0,0.0,0.0,0.0,0.0034270015977866325,900,113,156.12185072898865,1.3816092984866253,0.17346872303220962,0.09039239540893947,0.004042578568138803,0.16148795401670143,2.5237032979771928e-05,0.0,0.0,0.0,0.0,0.0,0.004042578568138803,450,57,53.37223672866821,0.9363550303275125,0.11860497050815158,0.07075024978257716
8
+ 6,0.002903090084760657,0.2928461281526556,1.3530937127217902e-05,0.03275197486082713,0.0,0.0,0.0,0.0,0.0029446042016045087,900,113,157.1868932247162,1.3910344533160723,0.174652103583018,0.09083371352305454,0.004153869318348977,0.08165856703591276,0.0004118713491015787,0.0,0.0,0.0,0.0,0.0,0.004153869318348977,450,57,52.67383909225464,0.9241024402149937,0.11705297576056586,0.08548504119869649
9
+ 7,0.0026706402006998866,0.20781586027989052,0.00013795001301429672,0.037672711697717506,0.0,0.0,0.0,0.0,0.0026999003414271607,900,113,155.8716015815735,1.3793947042617123,0.17319066842397055,0.09302399396500756,0.0036898051684774043,0.04890317060488071,0.00016461873778845238,0.0,0.0,0.0,0.0,0.0,0.0036898051684774043,450,57,52.411773443222046,0.9195047972495096,0.11647060765160455,0.08178644566878415
10
+ 8,0.002238100019361203,0.3384687315130162,1.1111032782711483e-05,0.03593144379142258,0.0,0.0,0.0,0.0,0.002262775602414169,900,113,155.77839064598083,1.3785698287254942,0.1730871007177565,0.09169465267157133,0.00233326900155387,0.18033007517983266,6.032218973090211e-05,0.0,0.0,0.0,0.0,0.0,0.00233326900155387,450,57,52.599425315856934,0.9227969353659111,0.11688761181301541,0.07049635971545062
11
+ 9,0.0012590437038711064,0.15078724619232048,4.081550376790824e-06,0.0263645450067189,0.0,0.0,0.0,0.0,0.0012720353449630138,900,113,156.98792576789856,1.3892736793619342,0.1744310286309984,0.09350108472317194,0.001966165854424212,0.3141539510364837,9.730311759192344e-05,0.0,0.0,0.0,0.0,0.0,0.001966165854424212,450,57,53.111929416656494,0.9317882353799385,0.11802650981479221,0.06570776830950197
12
+ 10,0.001532604441874557,0.1929226224414785,3.391608327091929e-06,0.03449563190340996,0.0,0.0,0.0,0.0,0.00154851471255016,900,113,156.23109221458435,1.3825760372972067,0.17359010246064926,0.0936134795996204,0.002808738038454774,0.16682867217481775,0.0001628730561616128,0.0,0.0,0.0,0.0,0.0,0.002808738038454774,450,57,52.47261381149292,0.9205721721314547,0.11660580846998427,0.07371748221646014
13
+ 11,0.0019407396270738294,0.3653798322266867,2.2870681320568346e-05,0.03755757513559527,0.0,0.0,0.0,0.0,0.001960661863623601,900,113,157.5604350566864,1.3943401332450125,0.1750671500629849,0.09302689506779466,0.0022611901594912828,0.223774224152935,9.727328464069852e-05,0.0,0.0,0.0,0.0,0.0,0.0022611901594912828,450,57,52.695061922073364,0.9244747705626906,0.11710013760460748,0.06702947569602545
14
+ 12,0.0015795660438016057,0.20200802482515098,3.2623586478791824e-06,0.03782492588998543,0.0,0.0,0.0,0.0,0.0015964161236964476,900,113,157.62508010864258,1.3949122133508194,0.17513897789849175,0.09343803705301433,0.0019024741732250226,0.16194683409677263,0.00010373163840122158,0.0,0.0,0.0,0.0,0.0,0.0019024741732250226,450,57,53.39478373527527,0.9367505918469345,0.11865507496727837,0.07240477975523263
15
+ 13,0.0012615600261617348,0.17734704436294219,6.4640958073491015e-06,0.032005395059370334,0.0,0.0,0.0,0.0,0.0012749444810389024,900,113,157.78692436218262,1.3963444633821471,0.17531880484686957,0.09475085942025206,0.003454031080505552,0.19643008009194018,0.00048754797153507685,0.0,0.0,0.0,0.0,0.0,0.003454031080505552,450,57,53.61415338516235,0.9405991821958307,0.11914256307813856,0.08090297263850899
16
+ 14,0.001146070581356374,0.12761228089974574,5.205212322840684e-06,0.029757227330572074,0.0,0.0,0.0,0.0,0.001157807001274907,900,113,158.21624660491943,1.4001437752647738,0.1757958295610216,0.09795315121918653,0.0023975990749345835,0.05156347854368074,0.00027527655832503,0.0,0.0,0.0,0.0,0.0,0.0023975990749345835,450,57,53.09577012062073,0.9315047389582584,0.11799060026804606,0.07778614515177253
17
+ 15,0.0008811901032135615,0.11619655411389741,1.1726831726480937e-06,0.026715771118178962,0.0,0.0,0.0,0.0,0.0008903484087042872,900,113,157.73501706123352,1.3958851067365798,0.17526113006803726,0.09735740258036989,0.0022739232058585105,0.07977884817566928,0.0003206384845873516,0.0,0.0,0.0,0.0,0.0,0.0022739232058585105,450,57,52.94484996795654,0.9288570169816938,0.11765522215101454,0.08302696749339239
18
+ 16,0.0007510672794726109,0.07160761223355694,8.719960328579189e-07,0.025627725821816258,0.0,0.0,0.0,0.0,0.0007592219165558668,900,113,155.887836933136,1.3795383799392564,0.17320870770348443,0.10122318941671236,0.001969620921461481,0.14806897564466326,0.0002989729056185993,0.0,0.0,0.0,0.0,0.0,0.001969620921461481,450,57,52.3817937374115,0.9189788374984473,0.11640398608313667,0.07865735263514675
insurance/lct_gan/mlu-eval.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
insurance/lct_gan/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a36059ea0c87193f58c8c86e84042e2cf8924cc753130163b6610f3adaaf99a9
3
+ size 38583573
insurance/lct_gan/params.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true, "forgive_over": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "shared", "patience": 10, "inds_init_mode": "fixnorm", "grad_clip": 0.7, "head_final_mul": "identity", "gradient_penalty_mode": "ALL", "synth_data": 2, "bias_lr_mul": 1.0, "bias_weight_decay": 0.05, "loss_balancer_beta": 0.79, "loss_balancer_r": 0.95, "dataset_size": 2048, "batch_size": 8, "epochs": 100, "n_warmup_steps": 100, "Optim": "diffgrad", "fixed_role_model": "lct_gan", "mse_mag": true, "mse_mag_target": 0.1, "mse_mag_multiply": true, "d_model": 256, "attn_activation": "leakyrelu", "tf_d_inner": 512, "tf_n_layers_enc": 4, "tf_n_head": 64, "tf_activation": "leakyhardsigmoid", "tf_activation_final": "leakyhardsigmoid", "tf_num_inds": 32, "ada_d_hid": 1024, "ada_n_layers": 7, "ada_activation": "relu", "ada_activation_final": "softsign", "head_d_hid": 128, "head_n_layers": 9, "head_n_head": 64, "head_activation": "prelu", "head_activation_final": "softsign", "models": ["lct_gan"], "max_seconds": 3600}
insurance/realtabformer/eval.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
2
+ realtabformer,0.010992311685847762,0.00499268272087041,0.0005800863841570736,4.444354057312012,0.20095407962799072,9.563033103942871,0.47066447138786316,4.327647218360653e-07,5.621514081954956,0.015106264501810074,1.2521374225616455,0.02408498153090477,0.1404879093170166,0.001446777256205678,10.065868139266968
insurance/realtabformer/history.csv ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
2
+ 0,0.029822943683765413,0.7290258940829034,0.00741166268384139,4.97963041305542,0.0,0.0,0.0,0.0,0.030589072092229292,900,113,167.66863656044006,1.4837932438976997,0.18629848506715563,0.11875602419990881,0.011320239888348927,0.8666478302036187,0.0016214420898055298,0.0,0.0,0.0,0.0,0.0,0.011320239888348927,450,57,53.24812388420105,0.9341776120035272,0.11832916418711345,0.05855883448793177
3
+ 1,0.008769770529793783,0.7716960565807296,0.00043396698297856246,1.4175501622094049,0.0,0.0,0.0,0.0,0.008993633385075049,900,113,166.8349289894104,1.4764153007912424,0.18537214332156712,0.09298099682921857,0.006308909958720001,0.1587160864532443,0.0010557503438670954,0.0,0.0,0.0,0.0,0.0,0.006308909958720001,450,57,54.493409872055054,0.9560247345974571,0.12109646638234456,0.07172276902322967
4
+ 2,0.0052478324769375225,0.4535665847330644,0.00015836110432902994,0.8015708598825667,0.0,0.0,0.0,0.0,0.005376618540095579,900,113,167.3072214126587,1.4805948797580415,0.18589691268073189,0.09632732554347115,0.005088021330302581,0.6213396075144173,0.00011936115010066849,0.0,0.0,0.0,0.0,0.0,0.005088021330302581,450,57,52.64925193786621,0.9236710866292318,0.11699833763970269,0.052176826527309525
5
+ 3,0.0030236450636746464,0.41981709804675077,4.2389841936938424e-05,0.5216422769096163,0.0,0.0,0.0,0.0,0.0031053333073052473,900,113,167.31541466712952,1.4806673864347746,0.18590601629681058,0.09325551584494852,0.0019911888805735443,0.2196139024221849,1.918494357696417e-05,0.0,0.0,0.0,0.0,0.0,0.0019911888805735443,450,57,53.6285560131073,0.9408518598790754,0.11917456891801623,0.07260068974523037
6
+ 4,0.0016811677904075219,0.24856906805416656,1.2625722286744727e-05,0.3810585221648216,0.0,0.0,0.0,0.0,0.001737628386148976,900,113,167.82943487167358,1.4852162378024212,0.1864771498574151,0.09434814767631809,0.0027509150341696416,0.07755351969834473,2.1412663001758838e-05,0.0,0.0,0.0,0.0,0.0,0.0027509150341696416,450,57,53.66714072227478,0.941528784601312,0.11926031271616618,0.08398696716482702
7
+ 5,0.001128127839474473,0.21778014314879,6.985574985683568e-06,0.28518845240275065,0.0,0.0,0.0,0.0,0.001170292465992841,900,113,160.84526014328003,1.4234093817989384,0.1787169557147556,0.09777768919310342,0.0021597746671694847,0.3930031767821128,3.3223664668092175e-05,0.0,0.0,0.0,0.0,0.0,0.0021597746671694847,450,57,49.775447607040405,0.8732534667901826,0.11061210579342312,0.07324986261511712
8
+ 6,0.0010482495646445184,0.11189875107239877,8.224311742242863e-06,0.2797618282172415,0.0,0.0,0.0,0.0,0.0010887739290612647,900,113,162.833313703537,1.4410027761374955,0.1809259041150411,0.09814084033621887,0.0019042760821240436,0.35190855250582836,1.6172142232353588e-05,0.0,0.0,0.0,0.0,0.0,0.0019042760821240436,450,57,52.383286476135254,0.9190050258971098,0.11640730328030056,0.06503926100732203
9
+ 7,0.001563700584617133,0.21518477240972333,4.5778025960834125e-06,0.30026551425457,0.0,0.0,0.0,0.0,0.0016092181016897989,900,113,165.6556372642517,1.465979090834086,0.1840618191825019,0.09587243578470914,0.003619586681533191,0.5286780524212229,0.00025007503121890473,0.0,0.0,0.0,0.0,0.0,0.003619586681533191,450,57,52.2312707901001,0.9163380840368438,0.11606949064466689,0.07745434307460591
10
+ 8,0.0011870296497040222,0.15783826198738787,2.913096641859988e-06,0.26780749612384375,0.0,0.0,0.0,0.0,0.0012269385414159235,900,113,164.9515302181244,1.4597480550276494,0.18327947802013822,0.0955763464315539,0.000885065957877992,0.12204710721440885,1.0928920943189783e-06,0.0,0.0,0.0,0.0,0.0,0.000885065957877992,450,57,50.95037126541138,0.8938661625510768,0.11322304725646973,0.07468080061912667
11
+ 9,0.0007502347028801321,0.07230353024762727,2.765409718966213e-06,0.2302845541636149,0.0,0.0,0.0,0.0,0.0007828545368586977,900,113,159.93213367462158,1.4153286165895715,0.17770237074957954,0.09794058360620937,0.002954557936366958,0.5175472071741417,0.0001666847069735606,0.0,0.0,0.0,0.0,0.0,0.002954557936366958,450,57,48.75002574920654,0.8552636096352025,0.10833339055379232,0.06763517065790661
12
+ 10,0.0008960025814141975,0.1276295194669632,9.211538751067532e-06,0.24726980176236896,0.0,0.0,0.0,0.0,0.0009317982033179659,900,113,158.89508271217346,1.406151174444013,0.17655009190241497,0.09495367640546992,0.002206301508348487,0.6127840376083674,3.638367311774459e-05,0.0,0.0,0.0,0.0,0.0,0.002206301508348487,450,57,48.69274377822876,0.8542586627759432,0.1082060972849528,0.07679086615609233
13
+ 11,0.0012414162013576263,0.15656802731034372,7.65507348889812e-06,0.2784161967039108,0.0,0.0,0.0,0.0,0.0012826652981392625,900,113,159.0150740146637,1.4072130443775548,0.17668341557184855,0.09698942330031268,0.0015209214665810578,1.1925622708846475,1.541826602484448e-05,0.0,0.0,0.0,0.0,0.0,0.0015209214665810578,450,57,48.32934379577637,0.8478832244873047,0.10739854176839193,0.060592792895540856
14
+ 12,0.000802709650840067,0.1468484333481462,6.78528227740518e-07,0.21540525201294158,0.0,0.0,0.0,0.0,0.000833754398206818,900,113,158.784077167511,1.4051688244912477,0.17642675240834554,0.09867627354981624,0.001989033992609216,0.6257210954253898,2.2637783627021217e-05,0.0,0.0,0.0,0.0,0.0,0.001989033992609216,450,57,48.381258964538574,0.8487940169217294,0.10751390881008573,0.07426875439807445
15
+ 13,0.0006656886564370426,0.08703389815338669,6.284277396025041e-07,0.20659642385111915,0.0,0.0,0.0,0.0,0.0006952917674950893,900,113,157.6981496810913,1.3955588467353213,0.17522016631232368,0.09772866989065589,0.0010757716005254123,0.41428317912765056,2.2639515015956634e-05,0.0,0.0,0.0,0.0,0.0,0.0010757716005254123,450,57,47.70015549659729,0.8368448332736367,0.10600034554799398,0.07492752365049041
16
+ 14,0.0003420950226265834,0.03358766082648436,3.4995474424948957e-07,0.13486777688066165,0.0,0.0,0.0,0.0,0.0003607326176521989,900,113,157.660076379776,1.3952219148652742,0.17517786264419555,0.0989597013500412,0.0008682416723038639,1.067326461550605,1.39864987777777e-05,0.0,0.0,0.0,0.0,0.0,0.0008682416723038639,450,57,47.50433969497681,0.8334094683329264,0.10556519932217068,0.06998455429269948
17
+ 15,0.00032250956939404,0.03796051059134845,2.0489448113398992e-07,0.13379798481861752,0.0,0.0,0.0,0.0,0.00034089527511645834,900,113,158.42128372192383,1.4019582630258747,0.17602364857991537,0.09779412188954585,0.0008600625935489208,1.0569037955788663,1.91862969400389e-05,0.0,0.0,0.0,0.0,0.0,0.0008600625935489208,450,57,48.475810289382935,0.8504528120944375,0.10772402286529541,0.06928659294192728
insurance/realtabformer/mlu-eval.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
insurance/realtabformer/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ebd1a49d8056abc7e60ed44b5a73d928fd175a32a62b6f2d0610e0c0ecf1c526
3
+ size 43505805
insurance/realtabformer/params.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true, "forgive_over": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "shared", "patience": 10, "inds_init_mode": "fixnorm", "grad_clip": 0.7, "head_final_mul": "identity", "gradient_penalty_mode": "ALL", "synth_data": 2, "bias_lr_mul": 1.0, "bias_weight_decay": 0.05, "loss_balancer_beta": 0.79, "loss_balancer_r": 0.95, "dataset_size": 2048, "batch_size": 8, "epochs": 100, "n_warmup_steps": 100, "Optim": "diffgrad", "fixed_role_model": "realtabformer", "mse_mag": true, "mse_mag_target": 0.1, "mse_mag_multiply": true, "d_model": 256, "attn_activation": "leakyrelu", "tf_d_inner": 512, "tf_n_layers_enc": 4, "tf_n_head": 64, "tf_activation": "leakyhardsigmoid", "tf_activation_final": "leakyhardsigmoid", "tf_num_inds": 32, "ada_d_hid": 1024, "ada_n_layers": 7, "ada_activation": "relu", "ada_activation_final": "softsign", "head_d_hid": 128, "head_n_layers": 9, "head_n_head": 64, "head_activation": "prelu", "head_activation_final": "softsign", "models": ["realtabformer"], "max_seconds": 3600}
insurance/tab_ddpm_concat/eval.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
2
+ tab_ddpm_concat,0.005787722477074393,0.5310536811323393,0.019623511412314006,6.455421447753906,0.09179345518350601,0.9963166117668152,0.1400327980518341,9.540074643155094e-06,2.324864387512207,0.09191911667585373,4.1357741355896,0.14008395373821259,0.03438407555222511,1.4360357522964478,8.780285835266113
insurance/tab_ddpm_concat/history.csv ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
2
+ 0,0.03592884845203823,7.2996755396477075,0.009430593892273626,0.007427950899711706,0.0,0.0,0.0,0.0,0.09437693562669058,900,113,140.96822214126587,1.2475063906306714,0.15663135793473987,0.05326718922737425,0.014524129554629325,4.156348642464237,0.0009876720450387843,0.0,0.0,0.0,0.0,0.0,0.014524129554629325,450,57,44.241886377334595,0.7761734452163964,0.09831530306074354,0.03724088621113384
3
+ 1,0.014371616853814986,6.227466808899376,0.00030516517503158583,6.165503883999514e-05,0.0,0.0,0.0,0.0,0.2792855604737997,900,113,141.62208795547485,1.2532928137652641,0.15735787550608318,0.03787480419155507,0.01340457120962027,5.29680332468604,0.0002864654652569243,0.0,0.0,0.0,0.0,0.0,0.01340457120962027,450,57,45.41609525680542,0.7967736009965863,0.10092465612623426,0.030410112535352248
4
+ 2,0.01384979708948069,7.497149070959221,0.00021502181714296655,1.949448532994009e-05,0.0,0.0,0.0,0.0,0.11678517651847667,900,113,143.15943694114685,1.2668976720455474,0.15906604104571873,0.033135497126629394,0.014283642331914355,5.459189098905386,0.0005690211627043279,0.0,0.0,0.0,0.0,0.0,0.014283642331914355,450,57,46.33096122741699,0.8128238811827543,0.1029576916164822,0.026298031341611294
5
+ 3,0.013210438237422042,6.6640938428891605,0.00014334700532385921,4.62662972798474e-06,0.0,0.0,0.0,0.0,0.2998970259560479,900,113,146.3652093410492,1.2952673393013203,0.16262801037894356,0.03572286798132468,0.013382001433314548,4.643842804434571,0.00048086737289648024,0.0,0.0,0.0,0.0,0.0,0.013382001433314548,450,57,47.94945430755615,0.8412184966237921,0.10655434290568033,0.029577646339148805
6
+ 4,0.013310110395153363,6.041971506328766,0.00019507275757620039,0.0,0.0,0.0,0.0,0.0,0.013310110395153363,900,113,143.80060958862305,1.2725717662710003,0.15977845509847005,0.038453474558428326,0.013661402131223844,4.1234811326556935,0.00045455218989546767,0.0,0.0,0.0,0.0,0.0,0.013661402131223844,450,57,45.394694089889526,0.7963981419278864,0.10087709797753228,0.0333288926100195
7
+ 5,0.013588898373353812,7.486738423658624,0.0002779256184064716,0.00022828346642199903,0.0,0.0,0.0,0.0,0.08210628287142349,900,113,141.604829788208,1.2531400866213098,0.15733869976467557,0.03336982285620364,0.015309163269638602,11.10999310352408,0.0006529700210521443,0.0,0.0,0.0,0.0,0.0,0.015309163269638602,450,57,45.059311866760254,0.7905142432764957,0.10013180414835612,0.01890142163948009
8
+ 6,0.01360072170642929,9.74690950918578,0.0002635270801760223,9.69009121440144e-06,0.0,0.0,0.0,0.0,3.451531191302153,900,113,140.73290538787842,1.2454239414856496,0.15636989487542047,0.029921204000052097,0.013079761469271034,5.622179901249935,0.00037308515279698264,0.0,0.0,0.0,0.0,0.0,0.013079761469271034,450,57,43.94645428657532,0.7709904260802687,0.09765878730350071,0.024666702521866875
9
+ 7,0.013433132627978921,7.019836858438262,0.0001345649655801632,2.2781116680966485e-05,0.0,0.0,0.0,0.0,0.16974145690082676,900,113,139.68711066246033,1.2361691209067285,0.15520790073606702,0.034294621425524224,0.013359390444432696,5.3989552207959495,0.00037170510294370413,0.0,0.0,0.0,0.0,0.0,0.013359390444432696,450,57,44.08630442619324,0.7734439373016357,0.09796956539154053,0.028002551381002393
10
+ 8,0.013708434053179291,8.164382093871072,0.00011137482917280241,0.0002664083573553297,0.0,0.0,0.0,0.0,0.02415950643726521,900,113,140.14124727249146,1.2401880289601013,0.15571249696943495,0.03384085310688984,0.013566718476617503,6.5259337021869435,0.00033163626410593374,0.0,0.0,0.0,0.0,0.0,0.013566718476617503,450,57,43.685551166534424,0.766413178360253,0.09707900259229872,0.027984272067745525
11
+ 9,0.013571654115286139,8.742722461416001,0.0003018974633938696,1.6721362351543374e-05,0.0,0.0,0.0,0.0,0.027267425186518167,900,113,140.09695625305176,1.2397960730358564,0.15566328472561305,0.030262669457732577,0.01325964125830473,5.029873872576148,0.00043549182230303055,0.0,0.0,0.0,0.0,0.0,0.01325964125830473,450,57,43.787155866622925,0.7681957169582969,0.09730479081471761,0.025762035086620273
12
+ 10,0.014379269344628685,7.366559921767701,0.00037252366738006987,0.0003263944470220142,0.0,0.0,0.0,0.0,0.017774089858867227,900,113,142.94084358215332,1.2649632175411798,0.1588231595357259,0.035403397502954556,0.013996612014921589,7.5169032518093255,0.0005061832475970126,0.0,0.0,0.0,0.0,0.0,0.013996612014921589,450,57,46.45759129524231,0.8150454613200405,0.10323909176720514,0.02128785624773356
13
+ 11,0.013456141208298505,8.885624297614187,0.0001408637665720865,6.312208974526988e-06,0.0,0.0,0.0,0.0,0.12043976681565659,900,113,140.75263118743896,1.2455985060835306,0.15639181243048775,0.03045645008374632,0.015452116089096914,10.887882433234672,0.0007382755149755995,0.0,0.0,0.0,0.0,0.0,0.015452116089096914,450,57,44.30102300643921,0.7772109299375299,0.09844671779208714,0.015029358677566051
14
+ 12,0.013939871930827697,10.302565499742524,0.00030740466281282784,1.2590549886226655e-06,0.0,0.0,0.0,0.0,0.05802687374461028,900,113,140.6320457458496,1.244531378281855,0.15625782860649956,0.026154642290048366,0.013364347005262971,4.240572321283827,0.0003598822188279074,0.0,0.0,0.0,0.0,0.0,0.013364347005262971,450,57,44.313255310058594,0.777425531755414,0.0984739006890191,0.03320887988727344
15
+ 13,0.013616377720609307,7.797813521520289,0.0002454437854082967,8.744494782553778e-05,0.0,0.0,0.0,0.0,0.013847984937537047,900,113,141.74561762809753,1.2543859967088278,0.15749513069788615,0.03314163033084004,0.014651317608594481,6.794933559461737,0.0006282155579800827,0.0,0.0,0.0,0.0,0.0,0.014651317608594481,450,57,47.70662569999695,0.8369583456139815,0.10601472377777099,0.02225242922768781
16
+ 14,0.013695524584295021,8.25407733499283,0.00019777481913188744,1.3750431502962278e-05,0.0,0.0,0.0,0.0,0.021342603873668445,900,113,142.06192708015442,1.2571851953995967,0.157846585644616,0.03127054079328623,0.013534503092782365,4.2766512755933626,0.00026410735567626234,0.0,0.0,0.0,0.0,0.0,0.013534503092782365,450,57,44.758880853652954,0.7852435237482974,0.09946417967478434,0.03563740865833927
17
+ 15,0.013962997518893745,8.13189057305249,0.00029005304642092415,2.4200141843822267e-05,0.0,0.0,0.0,0.0,0.015241746428526111,900,113,141.51434302330017,1.252339318790267,0.15723815891477796,0.03115090290166899,0.013381186781658066,5.749990332730752,0.0003508059210490602,0.0,0.0,0.0,0.0,0.0,0.013381186781658066,450,57,44.66229581832886,0.7835490494443659,0.09924954626295301,0.027999498015433028
insurance/tab_ddpm_concat/mlu-eval.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
insurance/tab_ddpm_concat/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fb00eb0a3fc2be866c2dbee35c72985295e305f44a2cf6fffdd10153f525a448
3
+ size 38514197
insurance/tab_ddpm_concat/params.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true, "forgive_over": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "shared", "patience": 10, "inds_init_mode": "fixnorm", "grad_clip": 0.7, "head_final_mul": "identity", "gradient_penalty_mode": "ALL", "synth_data": 2, "bias_lr_mul": 1.0, "bias_weight_decay": 0.05, "loss_balancer_beta": 0.79, "loss_balancer_r": 0.95, "dataset_size": 2048, "batch_size": 8, "epochs": 100, "n_warmup_steps": 100, "Optim": "diffgrad", "fixed_role_model": "tab_ddpm_concat", "mse_mag": true, "mse_mag_target": 0.1, "mse_mag_multiply": true, "d_model": 256, "attn_activation": "leakyrelu", "tf_d_inner": 512, "tf_n_layers_enc": 4, "tf_n_head": 64, "tf_activation": "leakyhardsigmoid", "tf_activation_final": "leakyhardsigmoid", "tf_num_inds": 32, "ada_d_hid": 1024, "ada_n_layers": 7, "ada_activation": "relu", "ada_activation_final": "softsign", "head_d_hid": 128, "head_n_layers": 9, "head_n_head": 64, "head_activation": "prelu", "head_activation_final": "softsign", "models": ["tab_ddpm_concat"], "max_seconds": 3600}
insurance/tvae/eval.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
2
+ tvae,0.005972265340193796,0.004282069878915665,0.00043351158954291826,6.498488903045654,0.0045961132273077965,0.3815285265445709,0.008567946031689644,1.2653150349706266e-07,2.329435348510742,0.012678780592978,0.9071698188781738,0.020820939913392067,0.14727535843849182,2.352470396260742e-08,8.827924251556396
insurance/tvae/history.csv ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
2
+ 0,0.03318693633812169,0.5761553976163866,0.009452911170182633,0.5073339222537147,0.0,0.0,0.0,0.0,0.03357576689010279,900,113,140.07320928573608,1.2395859228826203,0.1556368992063734,0.1377185063222341,0.011029714790444511,0.10201651341138354,0.0008927303574836721,0.0,0.0,0.0,0.0,0.0,0.011029714790444511,450,57,44.65146517753601,0.7833590382023862,0.09922547817230225,0.09806239102934405
3
+ 1,0.008319700596491909,1.3899801381037156,0.00032640089475284936,0.08519980923169189,0.0,0.0,0.0,0.0,0.008420131139250265,900,113,141.50769209861755,1.252280461049713,0.15723076899846394,0.08328526641810889,0.0042552733277746784,0.797521198754859,4.676108655203848e-05,0.0,0.0,0.0,0.0,0.0,0.0042552733277746784,450,57,44.46026659011841,0.7800046770196212,0.09880059242248534,0.05082482615845245
4
+ 2,0.0030213647201243372,0.4467241249097572,2.9881127188979347e-05,0.03057058709777064,0.0,0.0,0.0,0.0,0.003060984404421308,900,113,141.14445447921753,1.249065968842633,0.15682717164357504,0.09463177754881635,0.004963098757434637,0.23140327160492807,0.0003732240848752491,0.0,0.0,0.0,0.0,0.0,0.004963098757434637,450,57,44.70943355560303,0.7843760272912812,0.09935429679022895,0.07549247669317481
5
+ 3,0.0017553741914970386,0.2178003895808945,4.943996762005542e-06,0.02547209452009863,0.0,0.0,0.0,0.0,0.001773931053143719,900,113,141.30070185661316,1.2504486889965767,0.15700077984068128,0.09433466191939284,0.002269126343291848,0.25996375590649135,0.000101374774749564,0.0,0.0,0.0,0.0,0.0,0.002269126343291848,450,57,44.660250663757324,0.7835131695396022,0.09924500147501628,0.06720550957119517
6
+ 4,0.0011990225297955073,0.19844268139087864,6.492888718708172e-06,0.02050420185758008,0.0,0.0,0.0,0.0,0.0012119675411183077,900,113,141.14717173576355,1.2490900153607394,0.15683019081751506,0.09565485499601449,0.0021288806346031683,0.3498717648343612,6.43633937075372e-05,0.0,0.0,0.0,0.0,0.0,0.0021288806346031683,450,57,44.87625551223755,0.7873027282848692,0.09972501224941678,0.06606765008090358
7
+ 5,0.0010727717719338317,0.1799432011435814,1.986284879775273e-06,0.019064169105970197,0.0,0.0,0.0,0.0,0.0010836002508100744,900,113,144.71740865707397,1.2806850323634864,0.1607971207300822,0.0960029606352997,0.0025687832964791193,0.24147355027886128,0.00013711076116941693,0.0,0.0,0.0,0.0,0.0,0.0025687832964791193,450,57,44.92388701438904,0.7881383686734919,0.09983086003197564,0.06918113802053165
8
+ 6,0.0008831181887621319,0.1315059885362297,1.4708352961886883e-06,0.01531516697154277,0.0,0.0,0.0,0.0,0.0008934507078892138,900,113,141.82023167610168,1.255046298018599,0.15757803519566854,0.09242146539674924,0.0017530873860091055,0.304231248577451,7.560512728054956e-05,0.0,0.0,0.0,0.0,0.0,0.0017530873860091055,450,57,44.82837462425232,0.7864627127061811,0.09961861027611627,0.06672564117858808
9
+ 7,0.0005385009764318562,0.07839578661780303,7.117195005359047e-07,0.012112246143321197,0.0,0.0,0.0,0.0,0.0005440399935145655,900,113,141.7691547870636,1.2545942901510052,0.15752128309673732,0.09718591164368971,0.0023827175304500592,2.154619756286098,0.00013736589136319527,0.0,0.0,0.0,0.0,0.0,0.0023827175304500592,450,57,46.83988165855408,0.8217523097991943,0.10408862590789796,0.06465611442723584
10
+ 8,0.0004240491726927252,0.046745863687660795,1.4920491366082138e-07,0.010413902575771013,0.0,0.0,0.0,0.0,0.0004284947737728039,900,113,147.48687982559204,1.3051936267751507,0.1638743109173245,0.0963774272529161,0.0019166796955202396,0.40897835704338553,4.469346806295368e-05,0.0,0.0,0.0,0.0,0.0,0.0019166796955202396,450,57,45.08631491661072,0.7909879809931705,0.1001918109258016,0.0743747265813382
11
+ 9,0.0005618307212070148,0.06497306443851963,5.931063866807492e-07,0.012489941705846124,0.0,0.0,0.0,0.0,0.0005679361823422369,900,113,142.84351229667664,1.2641018787316516,0.15871501366297405,0.09641832253376467,0.0017379091629603257,0.4798262932258854,2.2975261447157275e-05,0.0,0.0,0.0,0.0,0.0,0.0017379091629603257,450,57,46.893903970718384,0.822700069661726,0.10420867549048529,0.06686989114856706
12
+ 10,0.00031843584507846066,0.034923846742066104,2.1399565589458585e-07,0.009167301410602199,0.0,0.0,0.0,0.0,0.00032186484126011945,900,113,144.4654381275177,1.2784552046682982,0.16051715347501966,0.10122920503526663,0.0017598816551940722,0.5481770304940762,4.9227021056914934e-05,0.0,0.0,0.0,0.0,0.0,0.0017598816551940722,450,57,45.89448928833008,0.8051664787426329,0.10198775397406684,0.0716182768922871
13
+ 11,0.0002415686168306921,0.022000115441331264,3.5452563263274956e-08,0.007707212487649586,0.0,0.0,0.0,0.0,0.00024431025462238017,900,113,141.62232375144958,1.253294900455306,0.15735813750161065,0.09990070720689487,0.001438923770741288,0.40616633773342614,1.6548902490660263e-05,0.0,0.0,0.0,0.0,0.0,0.001438923770741288,450,57,44.89892363548279,0.7877004146575928,0.09977538585662842,0.07267187253098216
insurance/tvae/mlu-eval.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
insurance/tvae/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:86800e68cd854a15a18fb0b685da25ff930b8062fcf84265ce348c33e151e6c1
3
+ size 38612117
insurance/tvae/params.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true, "forgive_over": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "shared", "patience": 10, "inds_init_mode": "fixnorm", "grad_clip": 0.7, "head_final_mul": "identity", "gradient_penalty_mode": "ALL", "synth_data": 2, "bias_lr_mul": 1.0, "bias_weight_decay": 0.05, "loss_balancer_beta": 0.79, "loss_balancer_r": 0.95, "dataset_size": 2048, "batch_size": 8, "epochs": 100, "n_warmup_steps": 100, "Optim": "diffgrad", "fixed_role_model": "tvae", "mse_mag": true, "mse_mag_target": 0.1, "mse_mag_multiply": true, "d_model": 256, "attn_activation": "leakyrelu", "tf_d_inner": 512, "tf_n_layers_enc": 4, "tf_n_head": 64, "tf_activation": "leakyhardsigmoid", "tf_activation_final": "leakyhardsigmoid", "tf_num_inds": 32, "ada_d_hid": 1024, "ada_n_layers": 7, "ada_activation": "relu", "ada_activation_final": "softsign", "head_d_hid": 128, "head_n_layers": 9, "head_n_head": 64, "head_activation": "prelu", "head_activation_final": "softsign", "models": ["tvae"], "max_seconds": 3600}
treatment/lct_gan/eval.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
2
+ lct_gan,0.0,0.00044872478942392847,0.004746775878468595,11.582196950912476,0.09063062071800232,1.6706515550613403,0.14325737953186035,3.869256761390716e-05,6.415107011795044,0.04724160209298134,4645954.5,0.06889685243368149,0.23633873462677002,0.0008537429966963828,17.99730396270752
treatment/lct_gan/history.csv ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ ,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
2
+ 0,0.08288680652415173,14.884825719568891,0.014639717982716328,0.007401900000145865,0.0,0.0,0.0,0.0,0.406318084973221,900,225,398.05458784103394,1.7691315015157063,0.4422828753789266,0.04329945017169747,0.01646603072894332,1.1389748531219954,0.001041307360675295,0.0,0.0,0.0,0.0,0.0,0.01646603072894332,450,113,104.75376605987549,0.9270244784059778,0.23278614679972331,0.1397827780404771
3
+ 1,0.014255170632645281,0.24843680457298686,0.000636702568035041,0.15559853479266167,0.0,0.0,0.0,0.0,0.014447312605981198,900,225,406.69940519332886,1.8075529119703504,0.4518882279925876,0.22602122453765736,0.011084747278914115,0.774564493527634,0.00036442376687365014,0.0,0.0,0.0,0.0,0.0,0.011084747278914115,450,113,105.70433187484741,0.9354365652641364,0.23489851527743869,0.13189282911609182
4
+ 2,0.009175471019561883,0.14389958513339807,0.0003093209020328993,0.16885624952562567,0.0,0.0,0.0,0.0,0.00928754332613001,900,225,406.6975419521332,1.8075446308983696,0.4518861577245924,0.23166270198714403,0.00785602382393045,1.3146459211567783,0.0001525110592525784,0.0,0.0,0.0,0.0,0.0,0.00785602382393045,450,113,105.51167917251587,0.9337316740930608,0.23447039816114637,0.11703323267914861
5
+ 3,0.006946129192502769,0.16631651740127382,0.00011629207969720338,0.1469966934973167,0.0,0.0,0.0,0.0,0.007030756754486194,900,225,404.7738826274872,1.798995033899943,0.44974875847498574,0.23086450531949393,0.006420711345431553,1.1872994360646385,8.72488180067034e-05,0.0,0.0,0.0,0.0,0.0,0.006420711345431553,450,113,103.53712010383606,0.9162577000339475,0.2300824891196357,0.1300600932704221
6
+ 4,0.005421485699508695,0.08202415918292068,0.00010085892813725515,0.13527608269825578,0.0,0.0,0.0,0.0,0.005492081102662875,900,225,403.6879127025604,1.7941685009002686,0.44854212522506715,0.23655629260775943,0.009256900170618995,2.8233697297467697,0.0006465031184101571,0.0,0.0,0.0,0.0,0.0,0.009256900170618995,450,113,101.8208520412445,0.9010694870906594,0.22626856009165447,0.13276277751503135
7
+ 5,0.0043073366452492535,0.19139563928810466,8.40804457109845e-05,0.12489897313269062,0.0,0.0,0.0,0.0,0.0043637448957360905,900,225,398.688072681427,1.7719469896952311,0.4429867474238078,0.23179641341906973,0.0055422348428186925,1.4432227016979617,5.823207058893942e-05,0.0,0.0,0.0,0.0,0.0,0.0055422348428186925,450,113,101.09947466850281,0.8946856165354231,0.22466549926333956,0.11363168490392674
8
+ 6,0.003263339866756117,0.3763446019548343,2.3110432702086737e-05,0.10734513330583771,0.0,0.0,0.0,0.0,0.0033083394544084713,900,225,398.69307565689087,1.7719692251417372,0.4429923062854343,0.24115597604735134,0.00555163197664519,1.1181977761929087,0.00010318647899446903,0.0,0.0,0.0,0.0,0.0,0.00555163197664519,450,113,101.25119686126709,0.8960282908076733,0.22500265969170463,0.11773095095579861
treatment/lct_gan/mlu-eval.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
treatment/lct_gan/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7605c6a45b313a0595da3f1d70c51c9d4cd98069d6756acf793194c058040638
3
+ size 74778241
treatment/lct_gan/params.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true, "forgive_over": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "shared", "patience": 10, "inds_init_mode": "torch", "grad_clip": 0.8, "gradient_penalty_mode": "ALL", "synth_data": 2, "bias_lr_mul": 1.0, "bias_weight_decay": 0.1, "loss_balancer_beta": 0.73, "loss_balancer_r": 0.94, "dataset_size": 2048, "batch_size": 4, "epochs": 100, "lr_mul": 0.04, "n_warmup_steps": 220, "Optim": "diffgrad", "fixed_role_model": "lct_gan", "mse_mag": true, "mse_mag_target": 0.2, "mse_mag_multiply": true, "d_model": 512, "attn_activation": "leakyhardsigmoid", "tf_d_inner": 512, "tf_n_layers_enc": 4, "tf_n_head": 64, "tf_activation": "leakyhardtanh", "tf_activation_final": "leakyhardtanh", "tf_num_inds": 64, "ada_d_hid": 1024, "ada_n_layers": 7, "ada_activation": "selu", "ada_activation_final": "leakyhardsigmoid", "head_d_hid": 128, "head_n_layers": 8, "head_n_head": 64, "head_activation": "leakyhardsigmoid", "head_activation_final": "leakyhardsigmoid", "models": ["lct_gan"], "max_seconds": 3600}
treatment/realtabformer/eval.csv ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ ,avg_g_cos_loss,avg_g_mag_loss,avg_loss,grad_duration,grad_mae,grad_mape,grad_rmse,mean_pred_loss,pred_duration,pred_mae,pred_mape,pred_rmse,pred_std,std_loss,total_duration
2
+ realtabformer,0.0,0.3809523773193367,0.005029816041935322,6.139832019805908,0.46518340706825256,7.454497337341309,0.9615514278411865,0.00015098779113031924,27.904475212097168,0.043616220355033875,2446463.5,0.07092119008302689,0.24172629415988922,0.00011316310701658949,34.044307231903076
treatment/realtabformer/history.csv ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ ,avg_role_model_loss_train,avg_role_model_std_loss_train,avg_role_model_mean_pred_loss_train,avg_role_model_g_mag_loss_train,avg_role_model_g_cos_loss_train,avg_non_role_model_g_mag_loss_train,avg_non_role_model_g_cos_loss_train,avg_non_role_model_embed_loss_train,avg_loss_train,n_size_train,n_batch_train,duration_train,duration_batch_train,duration_size_train,avg_pred_std_train,avg_role_model_loss_test,avg_role_model_std_loss_test,avg_role_model_mean_pred_loss_test,avg_role_model_g_mag_loss_test,avg_role_model_g_cos_loss_test,avg_non_role_model_g_mag_loss_test,avg_non_role_model_g_cos_loss_test,avg_non_role_model_embed_loss_test,avg_loss_test,n_size_test,n_batch_test,duration_test,duration_batch_test,duration_size_test,avg_pred_std_test
2
+ 0,0.03997775016380672,4.1784445257359195,0.003654398211962573,1.125379170803628,0.0,0.0,0.0,0.0,0.04061377539565386,900,450,674.5398631095886,1.4989774735768635,0.7494887367884318,0.13649817022379668,0.014926234884779453,5.134242437605607,0.0010615593935528929,0.0,0.0,0.0,0.0,0.0,0.014926234884779453,450,225,201.71023845672607,0.8964899486965603,0.44824497434828015,0.10115677190537907
3
+ 1,0.013389052229801714,1.3475211419678794,0.0008060232592924967,0.8510899391982396,0.0,0.0,0.0,0.0,0.013631856909132883,900,450,668.8134686946869,1.4862521526548598,0.7431260763274299,0.1980984934745033,0.008773729122132234,3.148535844511999,0.0007421919603760864,0.0,0.0,0.0,0.0,0.0,0.008773729122132234,450,225,197.29309678077698,0.8768582079145644,0.4384291039572822,0.0931849179521747
4
+ 2,0.007238608380309618,1.7181790947342477,0.0001944728850099263,0.5850073061873101,0.0,0.0,0.0,0.0,0.0073909809483623376,900,450,671.3675940036774,1.4919279866748385,0.7459639933374193,0.18768813919843738,0.011670945049314128,4.705433707986743,0.0008353222444575463,0.0,0.0,0.0,0.0,0.0,0.011670945049314128,450,225,201.83496594429016,0.8970442930857341,0.44852214654286704,0.08814679903484034
5
+ 3,0.007349486502970738,1.3100048318551674,0.0001732898741345901,0.552599703557272,0.0,0.0,0.0,0.0,0.007496255471258072,900,450,675.7493937015533,1.5016653193367853,0.7508326596683926,0.1919305363571362,0.008832458678805387,4.634036721558722,0.000772223847405924,0.0,0.0,0.0,0.0,0.0,0.008832458678805387,450,225,201.21838569641113,0.8943039364284939,0.44715196821424696,0.07392007318481268
treatment/realtabformer/mlu-eval.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
treatment/realtabformer/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ca83017195316a2f5210cdd5f311db787ee889cbbb92e8fd4dad4f7a666bb2b7
3
+ size 78481207
treatment/realtabformer/params.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {"Body": "twin_encoder", "loss_balancer_meta": true, "loss_balancer_log": false, "loss_balancer_lbtw": false, "pma_skip_small": false, "isab_skip_small": false, "layer_norm": false, "pma_layer_norm": false, "attn_residual": true, "tf_n_layers_dec": false, "tf_isab_rank": 0, "tf_lora": false, "tf_layer_norm": false, "tf_pma_start": -1, "ada_n_seeds": 0, "head_n_seeds": 0, "tf_pma_low": 16, "gradient_penalty_kwargs": {"mag_loss": true, "mse_mag": true, "mag_corr": false, "seq_mag": false, "cos_loss": false, "mse_mag_kwargs": {"target": 1.0, "multiply": true, "forgive_over": true}, "mag_corr_kwargs": {"only_sign": false}, "cos_loss_kwargs": {"only_sign": true, "cos_matrix": false}}, "dropout": 0, "combine_mode": "diff_left", "tf_isab_mode": "separate", "grad_loss_fn": "mae", "single_model": true, "bias": true, "bias_final": true, "pma_ffn_mode": "shared", "patience": 10, "inds_init_mode": "torch", "grad_clip": 0.8, "gradient_penalty_mode": "ALL", "synth_data": 2, "bias_lr_mul": 1.0, "bias_weight_decay": 0.1, "loss_balancer_beta": 0.73, "loss_balancer_r": 0.94, "dataset_size": 2048, "batch_size": 2, "epochs": 100, "lr_mul": 0.04, "n_warmup_steps": 220, "Optim": "diffgrad", "fixed_role_model": "realtabformer", "mse_mag": true, "mse_mag_target": 0.2, "mse_mag_multiply": true, "d_model": 512, "attn_activation": "leakyhardsigmoid", "tf_d_inner": 512, "tf_n_layers_enc": 4, "tf_n_head": 64, "tf_activation": "leakyhardtanh", "tf_activation_final": "leakyhardtanh", "tf_num_inds": 64, "ada_d_hid": 1024, "ada_n_layers": 7, "ada_activation": "selu", "ada_activation_final": "leakyhardsigmoid", "head_d_hid": 128, "head_n_layers": 8, "head_n_head": 64, "head_activation": "leakyhardsigmoid", "head_activation_final": "leakyhardsigmoid", "models": ["realtabformer"], "max_seconds": 3600}