Jensen-holm commited on
Commit
2fc4a94
·
1 Parent(s): ceca02c

added loss history in the return to the user so they can make graphs and

Browse files
Files changed (1) hide show
  1. nn/train.py +5 -4
nn/train.py CHANGED
@@ -45,8 +45,8 @@ def train(nn: NN) -> dict:
45
  )
46
  # compute error & store it
47
  error = y_hat - y_train
48
- mse = mean_squared_error(y=y_train, y_hat=y_hat)
49
- loss_hist.append(mse)
50
 
51
  # compute derivatives of weights & biases
52
  # update weights & biases using gradient descent after
@@ -81,8 +81,9 @@ def train(nn: NN) -> dict:
81
  )
82
 
83
  return {
84
- "log loss": log_loss(y_true=y_test, y_pred=y_hat),
85
- "accuracy": accuracy_score(y_true=y_test, y_pred=y_hat)
 
86
  }
87
 
88
 
 
45
  )
46
  # compute error & store it
47
  error = y_hat - y_train
48
+ loss = log_loss(y_true=y_train, y_pred=y_hat)
49
+ loss_hist.append(loss)
50
 
51
  # compute derivatives of weights & biases
52
  # update weights & biases using gradient descent after
 
81
  )
82
 
83
  return {
84
+ "loss_hist": loss_hist,
85
+ "log_loss": log_loss(y_true=y_test, y_pred=y_hat),
86
+ "accuracy": accuracy_score(y_true=y_test, y_pred=y_hat),
87
  }
88
 
89