package nn import ( "fmt" "gonum.org/v1/gonum/mat" ) func (nn *NN) Backprop() { var ( activation = *nn.ActivationFunc // lossHist []float64 ) for i := 0; i < nn.Epochs; i++ { // compute output with current w + b // then compute loss & backprop hiddenOutput, err := computeOutput( nn.XTrain, nn.Wh, nn.Bh, activation, ) if err != nil { fmt.Printf("error computing hidden output: %v", err) } yHat, err := computeOutput( hiddenOutput, nn.Wo, nn.Bo, activation, ) if err != nil { fmt.Printf("error computing yHat: %v", err) } mse := meanSquaredError(nn.YTrain, yHat) fmt.Println(mse) } } func computeOutput(arr, w, b *mat.Dense, activationFunc func(float64) float64) (*mat.Dense, error) { // Check if any of the input matrices is nil if arr == nil || w == nil || b == nil { return nil, fmt.Errorf("Input matrices cannot be nil") } // Check input dimensions arrRows, arrCols := arr.Dims() wRows, wCols := w.Dims() bRows, bCols := b.Dims() if arrCols != wRows || bCols != wCols { return nil, fmt.Errorf("Matrix dimension mismatch: arr[%d, %d], w[%d, %d], b[%d, %d]", arrRows, arrCols, wRows, wCols, bRows, bCols) } // Compute the dot product between the input matrix 'arr' and the weight matrix 'w' var product mat.Dense product.Mul(arr, w) // Check dimensions of product and bias productRows, productCols := product.Dims() if productCols != bCols { return nil, fmt.Errorf("Matrix dimension mismatch: product[%d, %d], b[%d, %d]", productRows, productCols, bRows, bCols) } // Add the bias matrix 'b' to the product var result mat.Dense result.Add(&product, b) // Apply the activation function to the result applyActivation(&result, activationFunc) return &result, nil } func applyActivation(m *mat.Dense, f func(float64) float64) { r, c := m.Dims() data := m.RawMatrix().Data for i := 0; i < r*c; i++ { data[i] = f(data[i]) } } func meanSquaredError(y, yHat *mat.Dense) float64 { var sum float64 r, c := y.Dims() for row := 0; row < r; row++ { for col := 0; col < c; col++ { diff := y.At(row, col) - yHat.At(row, col) sum += (diff * diff) } } return sum / float64((r * c)) }