Commit
·
1e811f2
1
Parent(s):
14c7b13
Fixed bug in hparam tuning for XGBoost + added additional metrics to XGBoost reports
Browse files
protac_degradation_predictor/optuna_utils_xgboost.py
CHANGED
@@ -90,7 +90,7 @@ def train_and_evaluate_xgboost(
|
|
90 |
)
|
91 |
|
92 |
# Evaluate model
|
93 |
-
val_pred = model.
|
94 |
val_pred_binary = (val_pred > 0.5).astype(int)
|
95 |
metrics = {
|
96 |
'val_acc': accuracy_score(y_val, val_pred_binary),
|
@@ -102,7 +102,7 @@ def train_and_evaluate_xgboost(
|
|
102 |
preds = {'val_pred': val_pred}
|
103 |
|
104 |
if test_df is not None:
|
105 |
-
test_pred = model.
|
106 |
test_pred_binary = (test_pred > 0.5).astype(int)
|
107 |
metrics.update({
|
108 |
'test_acc': accuracy_score(y_test, test_pred_binary),
|
@@ -335,6 +335,7 @@ def xgboost_hyperparameter_tuning_and_training(
|
|
335 |
|
336 |
# Get the majority vote for the test predictions
|
337 |
majority_vote_metrics = get_majority_vote_metrics(test_preds, test_df, active_label)
|
|
|
338 |
majority_vote_report = pd.DataFrame([majority_vote_metrics])
|
339 |
majority_vote_report['model_type'] = 'XGBoost'
|
340 |
|
|
|
90 |
)
|
91 |
|
92 |
# Evaluate model
|
93 |
+
val_pred = model.predict(dval)
|
94 |
val_pred_binary = (val_pred > 0.5).astype(int)
|
95 |
metrics = {
|
96 |
'val_acc': accuracy_score(y_val, val_pred_binary),
|
|
|
102 |
preds = {'val_pred': val_pred}
|
103 |
|
104 |
if test_df is not None:
|
105 |
+
test_pred = model.predict(dtest)
|
106 |
test_pred_binary = (test_pred > 0.5).astype(int)
|
107 |
metrics.update({
|
108 |
'test_acc': accuracy_score(y_test, test_pred_binary),
|
|
|
335 |
|
336 |
# Get the majority vote for the test predictions
|
337 |
majority_vote_metrics = get_majority_vote_metrics(test_preds, test_df, active_label)
|
338 |
+
majority_vote_metrics.update(get_dataframe_stats(train_val_df, test_df=test_df, active_label=active_label))
|
339 |
majority_vote_report = pd.DataFrame([majority_vote_metrics])
|
340 |
majority_vote_report['model_type'] = 'XGBoost'
|
341 |
|