Jensen-holm commited on
Commit
663083f
·
1 Parent(s): fbe515e

think that the way we had teh plot function setup may have been causing some problems with buffers in the cloud

Browse files
Files changed (1) hide show
  1. neural_network/plot.py +4 -4
neural_network/plot.py CHANGED
@@ -9,13 +9,13 @@ from neural_network.neural_network import NeuralNetwork
9
  matplotlib.use("Agg")
10
 
11
  def plot(model: NeuralNetwork) -> None:
12
- _ = sns.scatterplot(
 
13
  x=np.arange(len(model.loss_history)),
14
  y=model.loss_history,
 
15
  )
16
  buf = io.BytesIO()
17
- plt.savefig(buf, format="svg")
18
- plt.clf()
19
- buf.seek(0)
20
  plot_data = base64.b64encode(buf.getvalue()).decode("utf-8")
21
  model.plot = plot_data
 
9
  matplotlib.use("Agg")
10
 
11
  def plot(model: NeuralNetwork) -> None:
12
+ fig, ax = plt.subplots()
13
+ sns.scatterplot(
14
  x=np.arange(len(model.loss_history)),
15
  y=model.loss_history,
16
+ ax=ax,
17
  )
18
  buf = io.BytesIO()
19
+ fig.savefig(buf, format="svg")
 
 
20
  plot_data = base64.b64encode(buf.getvalue()).decode("utf-8")
21
  model.plot = plot_data