Jensen-holm commited on
Commit
b6a95e0
·
1 Parent(s): 6d15033

making better loss history / epoch plot

Browse files
Files changed (1) hide show
  1. neural_network/plot.py +5 -1
neural_network/plot.py CHANGED
@@ -9,12 +9,16 @@ from neural_network.neural_network import NeuralNetwork
9
  matplotlib.use("Agg")
10
 
11
  def plot(model: NeuralNetwork) -> None:
 
12
  fig, ax = plt.subplots()
13
- sns.scatterplot(
14
  x=np.arange(len(model.loss_history)),
15
  y=model.loss_history,
16
  ax=ax,
17
  )
 
 
 
18
  buf = io.BytesIO()
19
  fig.savefig(buf, format="png")
20
  plt.close(fig)
 
9
  matplotlib.use("Agg")
10
 
11
  def plot(model: NeuralNetwork) -> None:
12
+ sns.set()
13
  fig, ax = plt.subplots()
14
+ sns.lineplot(
15
  x=np.arange(len(model.loss_history)),
16
  y=model.loss_history,
17
  ax=ax,
18
  )
19
+ plt.ylabel("Loss")
20
+ plt.xlabel("Epoch")
21
+ plt.title("Loss / Epoch")
22
  buf = io.BytesIO()
23
  fig.savefig(buf, format="png")
24
  plt.close(fig)