File size: 2,202 Bytes
0a04cd7
 
0987346
 
 
 
 
 
0a04cd7
0987346
 
 
 
0a04cd7
 
0987346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a04cd7
 
0987346
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0a04cd7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
package nn

import (
	"fmt"

	"gonum.org/v1/gonum/mat"
)

// Backprop runs the training loop for nn.Epochs iterations.
//
// Each epoch performs a forward pass: nn.XTrain goes through the hidden layer
// (nn.Wh, nn.Bh) and the result goes through the output layer (nn.Wo, nn.Bo),
// both using the activation function that nn.ActivationFunc points to. The
// mean squared error between nn.YTrain and the prediction is then printed to
// standard output.
//
// NOTE: this method does not yet compute gradients or update any weights, so
// the printed loss is not expected to change from one epoch to the next.
//
// If the activation function is unset, or either layer fails to evaluate, a
// message is printed and the method returns immediately. Continuing would pass
// a nil matrix on to the next step and panic.
func (nn *NN) Backprop() {
	if nn.ActivationFunc == nil || *nn.ActivationFunc == nil {
		fmt.Println("error: activation function is not set")
		return
	}
	activation := *nn.ActivationFunc

	// TODO: keep a per-epoch loss history once the weights are actually updated.
	for i := 0; i < nn.Epochs; i++ {
		// Forward pass through the hidden layer with the current weights and bias.
		hiddenOutput, err := computeOutput(nn.XTrain, nn.Wh, nn.Bh, activation)
		if err != nil {
			fmt.Printf("error computing hidden output: %v\n", err)
			return
		}

		// Forward pass through the output layer.
		yHat, err := computeOutput(hiddenOutput, nn.Wo, nn.Bo, activation)
		if err != nil {
			fmt.Printf("error computing yHat: %v\n", err)
			return
		}

		mse := meanSquaredError(nn.YTrain, yHat)
		fmt.Println(mse)
	}
}

// computeOutput evaluates one dense layer: activationFunc applied element-wise
// to arr*w + b.
//
// With arr of shape [n, in] and w of shape [in, out], b must have out columns
// and either n rows (added element-wise) or exactly one row, which is then
// added to every row of the product.
//
// An error is returned when any matrix is nil or when the shapes are
// incompatible. The result is written to a new matrix; arr, w and b are only
// read. activationFunc must be non-nil.
func computeOutput(arr, w, b *mat.Dense, activationFunc func(float64) float64) (*mat.Dense, error) {
	// Check if any of the input matrices is nil
	if arr == nil || w == nil || b == nil {
		return nil, fmt.Errorf("Input matrices cannot be nil")
	}

	// Check input dimensions
	arrRows, arrCols := arr.Dims()
	wRows, wCols := w.Dims()
	bRows, bCols := b.Dims()

	if arrCols != wRows || bCols != wCols {
		return nil, fmt.Errorf("Matrix dimension mismatch: arr[%d, %d], w[%d, %d], b[%d, %d]", arrRows, arrCols, wRows, wCols, bRows, bCols)
	}

	// The product will be [arrRows, wCols]. Dense.Add panics unless both
	// operands have the same shape, so reject any bias that neither matches the
	// product row-for-row nor is a single broadcastable row.
	if bRows != arrRows && bRows != 1 {
		return nil, fmt.Errorf("Matrix dimension mismatch: product[%d, %d], b[%d, %d]", arrRows, wCols, bRows, bCols)
	}

	// Compute the dot product between the input matrix 'arr' and the weight matrix 'w'
	var product mat.Dense
	product.Mul(arr, w)

	// Add the bias matrix 'b' to the product
	var result mat.Dense
	if bRows == arrRows {
		result.Add(&product, b)
	} else {
		// Single-row bias: add b[0, col] to every row of the product.
		result.Apply(func(_, col int, v float64) float64 {
			return v + b.At(0, col)
		}, &product)
	}

	// Apply the activation function to the result
	applyActivation(&result, activationFunc)

	return &result, nil
}

// applyActivation replaces every element of m with f(element), in place.
//
// The backing slice is walked one row at a time using the raw matrix's Stride,
// so only the elements that belong to m are touched even when the stride is
// larger than the column count.
func applyActivation(m *mat.Dense, f func(float64) float64) {
	raw := m.RawMatrix()
	for i := 0; i < raw.Rows; i++ {
		start := i * raw.Stride
		row := raw.Data[start : start+raw.Cols]
		for j, v := range row {
			row[j] = f(v)
		}
	}
}

// meanSquaredError returns the mean of the squared element-wise differences
// between y and yHat.
//
// The iteration bounds are taken from y; yHat is read at the same indices.
func meanSquaredError(y, yHat *mat.Dense) float64 {
	rows, cols := y.Dims()

	total := 0.0
	for i := 0; i < rows; i++ {
		for j := 0; j < cols; j++ {
			residual := y.At(i, j) - yHat.At(i, j)
			total += residual * residual
		}
	}

	count := rows * cols
	return total / float64(count)
}