Spaces:
Sleeping
Sleeping
File size: 2,296 Bytes
79f3d28 e571d8c 79f3d28 e571d8c 79f3d28 e571d8c 79f3d28 e571d8c f308820 e571d8c 79f3d28 fd4ce9c 79f3d28 c504867 fd4ce9c c504867 e571d8c c504867 3563daa e571d8c c504867 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
from dataclasses import dataclass, field
from matplotlib import pyplot as plt
import matplotlib
import seaborn as sns
from typing import Callable
import numpy as np
import base64
import io
# Apply seaborn's default theme to figures created after this point.
sns.set()
# Select matplotlib's non-interactive "Agg" backend, which renders to
# buffers/files rather than a GUI window (presumably because this module runs
# on a headless server -- confirm against the deployment).
matplotlib.use("Agg")
@dataclass
class NeuralNetwork:
    """Two-layer feed-forward network with prediction, scoring and plotting.

    Attributes:
        epochs: Stored hyperparameter; not read by any method of this class.
        learning_rate: Stored hyperparameter; not read by any method of this
            class.
        activation_func: Callable applied to the output of *both* layers in
            :meth:`predict`.
        func_prime: Stored callable (by its name, presumably the derivative of
            ``activation_func``); only its ``__name__`` is read here.
        hidden_size: Stored hyperparameter; not read by any method of this
            class.
        w1, b1: Weights and bias of the first layer.
        w2, b2: Weights and bias of the second layer.
        mse: Mean squared error recorded by the last call to :meth:`eval`.
        loss_history: Loss values drawn by :meth:`plot`, one per x position.
    """

    epochs: int
    learning_rate: float
    activation_func: Callable
    func_prime: Callable
    hidden_size: int
    # ``np.ndarray`` is the array *type*; ``np.array`` is only a factory.
    w1: np.ndarray
    w2: np.ndarray
    b1: np.ndarray
    b2: np.ndarray
    mse: float = 0
    loss_history: list = field(default_factory=list)
    # Deliberately left un-annotated: the dataclass machinery therefore does
    # not treat it as a field, so it is not an ``__init__`` parameter and does
    # not take part in ``__eq__``/``__repr__``. Holds the base64 PNG text
    # produced by ``plot`` (or whatever ``set_plot_data`` was given).
    plt_data = None

    def predict(self, x: np.ndarray) -> np.ndarray:
        """Run a forward pass of ``x`` through both layers.

        The activation function is applied after the hidden layer and again
        after the output layer.
        """
        hidden = self.compute_node(x, self.w1, self.b1, self.activation_func)
        return self.compute_node(hidden, self.w2, self.b2, self.activation_func)

    def set_loss_hist(self, loss_hist: list) -> None:
        """Replace the stored loss history with ``loss_hist``."""
        self.loss_history = loss_hist

    def eval(self, X_test, y_test) -> None:
        """Store the mean squared error of the predictions for ``X_test``.

        The result is written to ``self.mse``; nothing is returned.
        """
        squared_error = (self.predict(X_test) - y_test) ** 2
        # Coerce to a built-in float: ``np.mean`` yields a NumPy scalar, and
        # e.g. ``np.float32`` is rejected by ``json.dumps`` when the
        # ``to_dict`` payload is serialised.
        self.mse = float(np.mean(squared_error))

    def set_plot_data(self, plot_data) -> None:
        """Store ``plot_data`` as the encoded plot exposed by ``to_dict``."""
        self.plt_data = plot_data

    def plot(self) -> None:
        """Render the loss curve to a PNG and store it as base64 text.

        The encoded image is stored through :meth:`set_plot_data`; nothing is
        returned.
        """
        sns.set()
        fig, ax = plt.subplots()
        try:
            sns.lineplot(
                x=np.arange(len(self.loss_history)),
                y=self.loss_history,
                ax=ax,
            )
            # Address the axes/figure explicitly instead of relying on
            # pyplot's implicit "current figure" state.
            ax.set_ylabel("Loss")
            ax.set_xlabel("Epoch")
            ax.set_title("Loss / Epoch")
            buffer = io.BytesIO()
            fig.savefig(buffer, format="png")
        finally:
            # Always release the figure, even if drawing or saving fails;
            # otherwise it stays registered with pyplot and leaks.
            plt.close(fig)
        encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
        self.set_plot_data(encoded)

    @staticmethod
    def compute_node(arr, w, b, func) -> np.ndarray:
        """Return ``func(arr . w + b)``, the output of a single layer."""
        return func(np.dot(arr, w) + b)

    @classmethod
    def from_dict(cls, dct):
        """Build an instance from a mapping of constructor keyword arguments.

        The keys must be dataclass field names. The output of :meth:`to_dict`
        is *not* valid input: it omits the required weights/biases and carries
        function names instead of callables.
        """
        return cls(**dct)

    def to_dict(self) -> dict:
        """Return a summary of the hyperparameters, score and encoded plot.

        The two callables are represented by their ``__name__``.
        """
        # Weights and biases (w1, w2, b1, b2) are intentionally not included.
        return {
            "epochs": self.epochs,
            "learning_rate": self.learning_rate,
            "activation_func": self.activation_func.__name__,
            "func_prime": self.func_prime.__name__,
            "hidden_size": self.hidden_size,
            "mse": self.mse,
            "plt_data": self.plt_data,
        }
|