ribesstefano commited on
Commit
1e811f2
·
1 Parent(s): 14c7b13

Fixed bug in hparam tuning for XGBoost + added additional metrics to XGBoost reports

Browse files
protac_degradation_predictor/optuna_utils_xgboost.py CHANGED
@@ -90,7 +90,7 @@ def train_and_evaluate_xgboost(
90
  )
91
 
92
  # Evaluate model
93
- val_pred = model.inplace_predict(dval)
94
  val_pred_binary = (val_pred > 0.5).astype(int)
95
  metrics = {
96
  'val_acc': accuracy_score(y_val, val_pred_binary),
@@ -102,7 +102,7 @@ def train_and_evaluate_xgboost(
102
  preds = {'val_pred': val_pred}
103
 
104
  if test_df is not None:
105
- test_pred = model.inplace_predict(dtest)
106
  test_pred_binary = (test_pred > 0.5).astype(int)
107
  metrics.update({
108
  'test_acc': accuracy_score(y_test, test_pred_binary),
@@ -335,6 +335,7 @@ def xgboost_hyperparameter_tuning_and_training(
335
 
336
  # Get the majority vote for the test predictions
337
  majority_vote_metrics = get_majority_vote_metrics(test_preds, test_df, active_label)
 
338
  majority_vote_report = pd.DataFrame([majority_vote_metrics])
339
  majority_vote_report['model_type'] = 'XGBoost'
340
 
 
90
  )
91
 
92
  # Evaluate model
93
+ val_pred = model.predict(dval)
94
  val_pred_binary = (val_pred > 0.5).astype(int)
95
  metrics = {
96
  'val_acc': accuracy_score(y_val, val_pred_binary),
 
102
  preds = {'val_pred': val_pred}
103
 
104
  if test_df is not None:
105
+ test_pred = model.predict(dtest)
106
  test_pred_binary = (test_pred > 0.5).astype(int)
107
  metrics.update({
108
  'test_acc': accuracy_score(y_test, test_pred_binary),
 
335
 
336
  # Get the majority vote for the test predictions
337
  majority_vote_metrics = get_majority_vote_metrics(test_preds, test_df, active_label)
338
+ majority_vote_metrics.update(get_dataframe_stats(train_val_df, test_df=test_df, active_label=active_label))
339
  majority_vote_report = pd.DataFrame([majority_vote_metrics])
340
  majority_vote_report['model_type'] = 'XGBoost'
341