Jensen-holm commited on
Commit
fbe515e
·
1 Parent(s): 2c9d2bf

believe to have created a function to build a svg plot and send it over the network

Browse files
neural_network/main.py CHANGED
@@ -3,6 +3,7 @@ import numpy as np
3
 
4
  from neural_network.opts import activation
5
  from neural_network.backprop import bp
 
6
 
7
 
8
  def init(X: np.array, hidden_size: int) -> dict:
@@ -46,4 +47,7 @@ def main(
46
  X_test=X_test,
47
  y_test=y_test,
48
  )
 
 
 
49
  return model.to_dict()
 
3
 
4
  from neural_network.opts import activation
5
  from neural_network.backprop import bp
6
+ from neural_network.plot import plot
7
 
8
 
9
  def init(X: np.array, hidden_size: int) -> dict:
 
47
  X_test=X_test,
48
  y_test=y_test,
49
  )
50
+
51
+ plot(model=model)
52
+
53
  return model.to_dict()
neural_network/neural_network.py CHANGED
@@ -19,6 +19,7 @@ class NeuralNetwork:
19
  loss_history: list = field(
20
  default_factory=lambda: [],
21
  )
 
22
 
23
  def predict(self, x: np.array) -> np.array:
24
  n1 = self.compute_node(x, self.w1, self.b1, self.activation_func)
@@ -51,4 +52,5 @@ class NeuralNetwork:
51
  "hidden_size": self.hidden_size,
52
  "mse": self.mse,
53
  "loss_history": self.loss_history,
 
54
  }
 
19
  loss_history: list = field(
20
  default_factory=lambda: [],
21
  )
22
+ plot = None
23
 
24
  def predict(self, x: np.array) -> np.array:
25
  n1 = self.compute_node(x, self.w1, self.b1, self.activation_func)
 
52
  "hidden_size": self.hidden_size,
53
  "mse": self.mse,
54
  "loss_history": self.loss_history,
55
+ "plot": self.plot,
56
  }
neural_network/plot.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import base64
3
+ import io
4
+ import seaborn as sns
5
+ import matplotlib
6
+ import matplotlib.pyplot as plt
7
+ from neural_network.neural_network import NeuralNetwork
8
+
9
+ matplotlib.use("Agg")
10
+
11
+ def plot(model: NeuralNetwork) -> None:
12
+ _ = sns.scatterplot(
13
+ x=np.arange(len(model.loss_history)),
14
+ y=model.loss_history,
15
+ )
16
+ buf = io.BytesIO()
17
+ plt.savefig(buf, format="svg")
18
+ plt.clf()
19
+ buf.seek(0)
20
+ plot_data = base64.b64encode(buf.getvalue()).decode("utf-8")
21
+ model.plot = plot_data