Spaces:
Sleeping
Sleeping
Commit
·
554811e
1
Parent(s):
27c0b63
added trainTestSplit() function
Browse files
- go.mod +5 -2
- go.sum +2 -1
- nn/.main.go.swp +0 -0
- nn/.split.go.swp +0 -0
- nn/main.go +18 -2
- nn/split.go +34 -1
- wrangle/.trainTestSplit.go.swp +0 -0
- wrangle/trainTestSplit.go +10 -0
go.mod
CHANGED
@@ -2,10 +2,13 @@ module github.com/Jensen-holm/ml-from-scratch
|
|
2 |
|
3 |
go 1.19
|
4 |
|
|
|
|
|
|
|
|
|
|
|
5 |
require (
|
6 |
github.com/andybalholm/brotli v1.0.5 // indirect
|
7 |
-
github.com/go-gota/gota v0.12.0 // indirect
|
8 |
-
github.com/gofiber/fiber/v2 v2.49.2 // indirect
|
9 |
github.com/google/uuid v1.3.1 // indirect
|
10 |
github.com/klauspost/compress v1.16.7 // indirect
|
11 |
github.com/mattn/go-colorable v0.1.13 // indirect
|
|
|
2 |
|
3 |
go 1.19
|
4 |
|
5 |
+
require (
|
6 |
+
github.com/go-gota/gota v0.12.0
|
7 |
+
github.com/gofiber/fiber/v2 v2.49.2
|
8 |
+
)
|
9 |
+
|
10 |
require (
|
11 |
github.com/andybalholm/brotli v1.0.5 // indirect
|
|
|
|
|
12 |
github.com/google/uuid v1.3.1 // indirect
|
13 |
github.com/klauspost/compress v1.16.7 // indirect
|
14 |
github.com/mattn/go-colorable v0.1.13 // indirect
|
go.sum
CHANGED
@@ -53,6 +53,7 @@ golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL
|
|
53 |
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
54 |
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
55 |
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
|
|
56 |
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
|
57 |
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
|
58 |
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
@@ -67,7 +68,6 @@ golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCc
|
|
67 |
golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
|
68 |
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
69 |
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
70 |
-
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6 h1:0PC75Fz/kyMGhL0e1QnypqK2kQMqKt9csD1GnMJR+Zk=
|
71 |
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
|
72 |
golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=
|
73 |
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
|
@@ -95,6 +95,7 @@ gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJ
|
|
95 |
gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0=
|
96 |
gonum.org/v1/gonum v0.9.1 h1:HCWmqqNoELL0RAQeKBXWtkp04mGk8koafcB4He6+uhc=
|
97 |
gonum.org/v1/gonum v0.9.1/go.mod h1:TZumC3NeyVQskjXqmyWt4S3bINhy7B4eYwW69EbyX+0=
|
|
|
98 |
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
|
99 |
gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc=
|
100 |
gonum.org/v1/plot v0.9.0/go.mod h1:3Pcqqmp6RHvJI72kgb8fThyUnav364FOsdDo2aGW5lY=
|
|
|
53 |
golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
54 |
golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
55 |
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
56 |
+
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3 h1:n9HxLrNxWWtEb1cA950nuEEj3QnKbtsCJ6KjcgisNUs=
|
57 |
golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
|
58 |
golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
|
59 |
golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
|
|
|
68 |
golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
|
69 |
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
70 |
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
|
|
71 |
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
|
72 |
golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=
|
73 |
golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
|
|
|
95 |
gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0=
|
96 |
gonum.org/v1/gonum v0.9.1 h1:HCWmqqNoELL0RAQeKBXWtkp04mGk8koafcB4He6+uhc=
|
97 |
gonum.org/v1/gonum v0.9.1/go.mod h1:TZumC3NeyVQskjXqmyWt4S3bINhy7B4eYwW69EbyX+0=
|
98 |
+
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc=
|
99 |
gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
|
100 |
gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc=
|
101 |
gonum.org/v1/plot v0.9.0/go.mod h1:3Pcqqmp6RHvJI72kgb8fThyUnav364FOsdDo2aGW5lY=
|
nn/.main.go.swp
ADDED
Binary file (12.3 kB). View file
|
|
nn/.split.go.swp
ADDED
Binary file (12.3 kB). View file
|
|
nn/main.go
CHANGED
@@ -16,7 +16,13 @@ type NN struct {
|
|
16 |
HiddenSize int `json:"hidden_size"`
|
17 |
LearningRate float64 `json:"learning_rate"`
|
18 |
ActivationFunc string `json:"activation"`
|
19 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
}
|
21 |
|
22 |
func NewNN(c *fiber.Ctx) (*NN, error) {
|
@@ -25,6 +31,16 @@ func NewNN(c *fiber.Ctx) (*NN, error) {
|
|
25 |
if err != nil {
|
26 |
return nil, fmt.Errorf("invalid JSON data: %v", err)
|
27 |
}
|
28 |
-
|
|
|
29 |
return newNN, nil
|
30 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
HiddenSize int `json:"hidden_size"`
|
17 |
LearningRate float64 `json:"learning_rate"`
|
18 |
ActivationFunc string `json:"activation"`
|
19 |
+
TestSize float64 `json:"test_size"`
|
20 |
+
|
21 |
+
Df *dataframe.DataFrame
|
22 |
+
XTrain dataframe.DataFrame
|
23 |
+
YTrain dataframe.DataFrame
|
24 |
+
XTest dataframe.DataFrame
|
25 |
+
YTest dataframe.DataFrame
|
26 |
}
|
27 |
|
28 |
func NewNN(c *fiber.Ctx) (*NN, error) {
|
|
|
31 |
if err != nil {
|
32 |
return nil, fmt.Errorf("invalid JSON data: %v", err)
|
33 |
}
|
34 |
+
df := dataframe.ReadCSV(strings.NewReader(newNN.CSVData))
|
35 |
+
newNN.Df = &df
|
36 |
return newNN, nil
|
37 |
}
|
38 |
+
|
39 |
+
func (nn *NN) Train() {
|
40 |
+
// train test split the data
|
41 |
+
|
42 |
+
// iterate n times where n = nn.Epochs
|
43 |
+
// use backprop algorithm on each iteration
|
44 |
+
// to fit the model to the data
|
45 |
+
|
46 |
+
}
|
nn/split.go
CHANGED
@@ -1,5 +1,38 @@
|
|
1 |
package nn
|
2 |
|
|
|
|
|
|
|
|
|
|
|
3 |
// implement train test split function
|
4 |
|
5 |
-
func trainTestSplit() {
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
package nn
|
2 |
|
3 |
+
import (
|
4 |
+
"math"
|
5 |
+
"math/rand"
|
6 |
+
)
|
7 |
+
|
8 |
// implement train test split function
|
9 |
|
10 |
+
func (nn *NN) trainTestSplit() {
|
11 |
+
// now we split the data into training
|
12 |
+
// and testing based on user specified
|
13 |
+
// nn.TestSize.
|
14 |
+
nRows := nn.Df.Nrow()
|
15 |
+
testRows := int(math.Floor(float64(nRows) * nn.TestSize))
|
16 |
+
|
17 |
+
// subset the testing data
|
18 |
+
// randomly select trainRows number of rows
|
19 |
+
randStrt := rand.Intn(int(math.Floor(float64(nRows) * nn.TestSize)))
|
20 |
+
test := nn.Df.Subset([]int{randStrt, randStrt + testRows})
|
21 |
+
|
22 |
+
// use what is left for training
|
23 |
+
allIndices := make([]int, nRows)
|
24 |
+
for i := range allIndices {
|
25 |
+
allIndices[i] = i
|
26 |
+
}
|
27 |
+
|
28 |
+
// Remove the test indices using slice append and variadic parameter
|
29 |
+
trainIndices := append(allIndices[:randStrt], allIndices[randStrt+testRows:]...)
|
30 |
+
|
31 |
+
// Create the train DataFrame using the trainIndices
|
32 |
+
train := nn.Df.Subset(trainIndices)
|
33 |
+
|
34 |
+
nn.XTrain = train.Select(nn.Features)
|
35 |
+
nn.YTrain = train.Select(nn.Target)
|
36 |
+
nn.XTest = test.Select(nn.Features)
|
37 |
+
nn.YTest = test.Select(nn.Target)
|
38 |
+
}
|
wrangle/.trainTestSplit.go.swp
ADDED
Binary file (12.3 kB). View file
|
|
wrangle/trainTestSplit.go
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
package wrangle
|
2 |
+
|
3 |
+
import (
|
4 |
+
"github.com/go-gota/gota/dataframe"
|
5 |
+
)
|
6 |
+
|
7 |
+
func TrainTestSplit(df *dataframe.DataFrame, test float64) *dataframe.DataFrame {
|
8 |
+
|
9 |
+
return
|
10 |
+
}
|