ribesstefano commited on
Commit
367cf2c
·
1 Parent(s): 8d92d67

Fixed bug on averaging predictions in package code

Browse files
protac_degradation_predictor/protac_degradation_predictor.py CHANGED
@@ -93,11 +93,12 @@ def get_protac_active_proba(
93
  prescaled_embeddings=False, # Normalization performed by the model
94
  )
95
  preds[ckpt_path] = sigmoid(pred).detach().cpu().numpy().flatten()
96
- axis = 1 if isinstance(protac_smiles, list) else None
 
97
  return {
98
- 'preds': np.array(list(preds.values())),
99
- 'mean': np.mean(list(preds.values()), axis=axis),
100
- 'majority_vote': np.mean(list(preds.values()), axis=axis) > 0.5,
101
  }
102
 
103
 
 
93
  prescaled_embeddings=False, # Normalization performed by the model
94
  )
95
  preds[ckpt_path] = sigmoid(pred).detach().cpu().numpy().flatten()
96
+ # NOTE: The predictions array has shape: (n_models, batch_size)
97
+ preds = np.array(list(preds.values()))
98
  return {
99
+ 'preds': preds,
100
+ 'mean': np.mean(preds, axis=0),
101
+ 'majority_vote': np.mean(preds, axis=0) > 0.5,
102
  }
103
 
104