Jensen-holm committed on
Commit
554811e
·
1 Parent(s): 27c0b63

added trainTestSplit() function

Browse files
go.mod CHANGED
@@ -2,10 +2,13 @@ module github.com/Jensen-holm/ml-from-scratch
2
 
3
  go 1.19
4
 
 
 
 
 
 
5
  require (
6
  github.com/andybalholm/brotli v1.0.5 // indirect
7
- github.com/go-gota/gota v0.12.0 // indirect
8
- github.com/gofiber/fiber/v2 v2.49.2 // indirect
9
  github.com/google/uuid v1.3.1 // indirect
10
  github.com/klauspost/compress v1.16.7 // indirect
11
  github.com/mattn/go-colorable v0.1.13 // indirect
 
2
 
3
  go 1.19
4
 
5
+ require (
6
+ github.com/go-gota/gota v0.12.0
7
+ github.com/gofiber/fiber/v2 v2.49.2
8
+ )
9
+
10
  require (
11
  github.com/andybalholm/brotli v1.0.5 // indirect
 
 
12
  github.com/google/uuid v1.3.1 // indirect
13
  github.com/klauspost/compress v1.16.7 // indirect
14
  github.com/mattn/go-colorable v0.1.13 // indirect
go.sum CHANGED
@@ -53,6 +53,7 @@ golang.org/x/exp v0.0.0-20180321215751-8460e604b9de/go.mod h1:CJ0aWSM057203Lf6IL
53
  golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
54
  golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
55
  golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
 
56
  golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
57
  golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
58
  golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
@@ -67,7 +68,6 @@ golang.org/x/mobile v0.0.0-20190719004257-d2bd2a29d028/go.mod h1:E/iHnbuqvinMTCc
67
  golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
68
  golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
69
  golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
70
- golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6 h1:0PC75Fz/kyMGhL0e1QnypqK2kQMqKt9csD1GnMJR+Zk=
71
  golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
72
  golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=
73
  golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
@@ -95,6 +95,7 @@ gonum.org/v1/gonum v0.0.0-20180816165407-929014505bf4/go.mod h1:Y+Yx5eoAFn32cQvJ
95
  gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0=
96
  gonum.org/v1/gonum v0.9.1 h1:HCWmqqNoELL0RAQeKBXWtkp04mGk8koafcB4He6+uhc=
97
  gonum.org/v1/gonum v0.9.1/go.mod h1:TZumC3NeyVQskjXqmyWt4S3bINhy7B4eYwW69EbyX+0=
 
98
  gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
99
  gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc=
100
  gonum.org/v1/plot v0.9.0/go.mod h1:3Pcqqmp6RHvJI72kgb8fThyUnav364FOsdDo2aGW5lY=
 
53
  golang.org/x/exp v0.0.0-20180807140117-3d87b88a115f/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
54
  golang.org/x/exp v0.0.0-20190125153040-c74c464bbbf2/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
55
  golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
56
+ golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3 h1:n9HxLrNxWWtEb1cA950nuEEj3QnKbtsCJ6KjcgisNUs=
57
  golang.org/x/exp v0.0.0-20191002040644-a1355ae1e2c3/go.mod h1:NOZ3BPKG0ec/BKJQgnvsSFpcKLM5xXVWnvZS97DWHgE=
58
  golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs=
59
  golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
 
68
  golang.org/x/mod v0.1.0/go.mod h1:0QHyrYULN0/3qlju5TqG8bIK38QM8yzMo5ekMj3DlcY=
69
  golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
70
  golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
 
71
  golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
72
  golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ=
73
  golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc=
 
95
  gonum.org/v1/gonum v0.8.2/go.mod h1:oe/vMfY3deqTw+1EZJhuvEW2iwGF1bW9wwu7XCu0+v0=
96
  gonum.org/v1/gonum v0.9.1 h1:HCWmqqNoELL0RAQeKBXWtkp04mGk8koafcB4He6+uhc=
97
  gonum.org/v1/gonum v0.9.1/go.mod h1:TZumC3NeyVQskjXqmyWt4S3bINhy7B4eYwW69EbyX+0=
98
+ gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0 h1:OE9mWmgKkjJyEmDAAtGMPjXu+YNeGvK9VTSHY6+Qihc=
99
  gonum.org/v1/netlib v0.0.0-20190313105609-8cb42192e0e0/go.mod h1:wa6Ws7BG/ESfp6dHfk7C6KdzKA7wR7u/rKwOGE66zvw=
100
  gonum.org/v1/plot v0.0.0-20190515093506-e2840ee46a6b/go.mod h1:Wt8AAjI+ypCyYX3nZBvf6cAIx93T+c/OS2HFAYskSZc=
101
  gonum.org/v1/plot v0.9.0/go.mod h1:3Pcqqmp6RHvJI72kgb8fThyUnav364FOsdDo2aGW5lY=
nn/.main.go.swp ADDED
Binary file (12.3 kB). View file
 
nn/.split.go.swp ADDED
Binary file (12.3 kB). View file
 
nn/main.go CHANGED
@@ -16,7 +16,13 @@ type NN struct {
16
  HiddenSize int `json:"hidden_size"`
17
  LearningRate float64 `json:"learning_rate"`
18
  ActivationFunc string `json:"activation"`
19
- Df dataframe.DataFrame
 
 
 
 
 
 
20
  }
21
 
22
  func NewNN(c *fiber.Ctx) (*NN, error) {
@@ -25,6 +31,16 @@ func NewNN(c *fiber.Ctx) (*NN, error) {
25
  if err != nil {
26
  return nil, fmt.Errorf("invalid JSON data: %v", err)
27
  }
28
- newNN.Df = dataframe.ReadCSV(strings.NewReader(newNN.CSVData))
 
29
  return newNN, nil
30
  }
 
 
 
 
 
 
 
 
 
 
16
  HiddenSize int `json:"hidden_size"`
17
  LearningRate float64 `json:"learning_rate"`
18
  ActivationFunc string `json:"activation"`
19
+ TestSize float64 `json:"test_size"`
20
+
21
+ Df *dataframe.DataFrame
22
+ XTrain dataframe.DataFrame
23
+ YTrain dataframe.DataFrame
24
+ XTest dataframe.DataFrame
25
+ YTest dataframe.DataFrame
26
  }
27
 
28
  func NewNN(c *fiber.Ctx) (*NN, error) {
 
31
  if err != nil {
32
  return nil, fmt.Errorf("invalid JSON data: %v", err)
33
  }
34
+ df := dataframe.ReadCSV(strings.NewReader(newNN.CSVData))
35
+ newNN.Df = &df
36
  return newNN, nil
37
  }
38
+
39
+ func (nn *NN) Train() {
40
+ // train test split the data
41
+
42
+ // iterate n times where n = nn.Epochs
43
+ // use backprop algorithm on each iteration
44
+ // to fit the model to the data
45
+
46
+ }
nn/split.go CHANGED
@@ -1,5 +1,38 @@
1
  package nn
2
 
 
 
 
 
 
3
  // implement train test split function
4
 
5
- func trainTestSplit() {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  package nn
2
 
3
+ import (
4
+ "math"
5
+ "math/rand"
6
+ )
7
+
8
  // implement train test split function
9
 
10
+ func (nn *NN) trainTestSplit() {
11
+ // now we split the data into training
12
+ // and testing based on user specified
13
+ // nn.TestSize.
14
+ nRows := nn.Df.Nrow()
15
+ testRows := int(math.Floor(float64(nRows) * nn.TestSize))
16
+
17
+ // subset the testing data
18
+ // randomly select testRows number of rows
19
+ randStrt := rand.Intn(int(math.Floor(float64(nRows) * nn.TestSize)))
20
+ test := nn.Df.Subset([]int{randStrt, randStrt + testRows})
21
+
22
+ // use what is left for training
23
+ allIndices := make([]int, nRows)
24
+ for i := range allIndices {
25
+ allIndices[i] = i
26
+ }
27
+
28
+ // Remove the test indices using slice append and variadic parameter
29
+ trainIndices := append(allIndices[:randStrt], allIndices[randStrt+testRows:]...)
30
+
31
+ // Create the train DataFrame using the trainIndices
32
+ train := nn.Df.Subset(trainIndices)
33
+
34
+ nn.XTrain = train.Select(nn.Features)
35
+ nn.YTrain = train.Select(nn.Target)
36
+ nn.XTest = test.Select(nn.Features)
37
+ nn.YTest = test.Select(nn.Target)
38
+ }
wrangle/.trainTestSplit.go.swp ADDED
Binary file (12.3 kB). View file
 
wrangle/trainTestSplit.go ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ package wrangle
2
+
3
+ import (
4
+ "github.com/go-gota/gota/dataframe"
5
+ )
6
+
7
+ func TrainTestSplit(df *dataframe.DataFrame, test float64) *dataframe.DataFrame {
8
+
9
+ return
10
+ }