Jensen-holm commited on
Commit
a4aa491
·
1 Parent(s): dfd3544

re configured the Alg interface and the NN sttuct, now we have to work

Browse files
Files changed (6) hide show
  1. alg/alg.go +1 -0
  2. nn/activation.go +19 -0
  3. nn/args.go +17 -0
  4. nn/main.go +57 -2
  5. request/request.go +6 -8
  6. server.go +8 -2
alg/alg.go CHANGED
@@ -1,5 +1,6 @@
1
  package alg
2
 
3
  type Alg interface {
 
4
  Train()
5
  }
 
1
  package alg
2
 
3
  type Alg interface {
4
+ New()
5
  Train()
6
  }
nn/activation.go ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package nn
2
+
3
+ var ActivationMap = map[string]func(){
4
+ "sigmoid": Sigmoid,
5
+ "tanh": Tanh,
6
+ "relu": Relu,
7
+ }
8
+
9
+ func Sigmoid() {}
10
+
11
+ func SigmoidPrime() {}
12
+
13
+ func Tanh() {}
14
+
15
+ func TanhPrime() {}
16
+
17
+ func Relu() {}
18
+
19
+ func ReluPrime() {}
nn/args.go ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ package nn
2
+
3
+ type NNArgs struct {
4
+ epochs int
5
+ hiddenSize int
6
+ learningRate float64
7
+ activationFunc func()
8
+ }
9
+
10
+ func NewArgs(argsMap map[string]interface{}) *NNArgs {
11
+ return &NNArgs{
12
+ epochs: argsMap["epochs"].(int),
13
+ hiddenSize: argsMap["hidden_size"].(int),
14
+ learningRate: argsMap["learning_rate"].(float64),
15
+ activationFunc: argsMap["activation"].(func()),
16
+ }
17
+ }
nn/main.go CHANGED
@@ -1,11 +1,66 @@
1
  package nn
2
 
3
- import "github.com/Jensen-holm/ml-from-scratch/alg"
 
 
 
 
 
4
 
5
  type NN struct {
6
  alg.Alg
 
7
  }
8
 
9
- func New(rp *RequestPayload) {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  }
 
1
  package nn
2
 
3
+ import (
4
+ "fmt"
5
+
6
+ "github.com/Jensen-holm/ml-from-scratch/alg"
7
+ "github.com/Jensen-holm/ml-from-scratch/request"
8
+ )
9
 
10
  type NN struct {
11
  alg.Alg
12
+ args *NNArgs
13
  }
14
 
15
+ func New(rp *request.Payload) (alg.Alg, error) {
16
+ // parse the args and make a new NN struct
17
+ nnArgs := make(map[string]interface{})
18
+ params := []string{
19
+ "epochs",
20
+ "activation",
21
+ "hidden_size",
22
+ "learning_rate",
23
+ }
24
+
25
+ for _, param := range params {
26
+ i, isIn := rp.Args[param]
27
+ if !isIn {
28
+ return nil, fmt.Errorf("user must specify %s", param)
29
+ }
30
+
31
+ switch param {
32
+ case "epochs":
33
+ if val, ok := i.(int); ok {
34
+ nnArgs[param] = val
35
+ } else {
36
+ return nil, fmt.Errorf("expected %s to be an int", param)
37
+ }
38
+ case "activation":
39
+ if val, ok := i.(string); ok {
40
+ nnArgs[param] = ActivationMap[val]
41
+ } else {
42
+ return nil, fmt.Errorf("expected %s to be a string", param)
43
+ }
44
+ case "hidden_size":
45
+ if val, ok := i.(int); ok {
46
+ nnArgs[param] = val
47
+ } else {
48
+ return nil, fmt.Errorf("expected %s to be an int", param)
49
+ }
50
+ case "learning_rate":
51
+ if val, ok := i.(float64); ok {
52
+ nnArgs[param] = val
53
+ } else {
54
+ return nil, fmt.Errorf("expected %s to be a float64", param)
55
+ }
56
+ default:
57
+ return nil, fmt.Errorf("unsupported parameter: %s", param)
58
+ }
59
+ }
60
+
61
+ args := NewArgs(nnArgs)
62
+ return &NN{
63
+ args: args,
64
+ }, nil
65
 
66
  }
request/request.go CHANGED
@@ -8,15 +8,13 @@ import (
8
  )
9
 
10
  type Payload struct {
11
- CSVData string `json:"csv_data"`
12
- Features []string `json:"features"`
13
- Target string `json:"target"`
14
- Epochs int `json:"epochs"`
15
- HiddenSize int `json:"hidden_size"`
16
- LearningRate float64 `json:"learning_rate"`
17
- ActivationFunc string `json:"activation_func"`
18
 
19
- Df dataframe.DataFrame
 
20
  }
21
 
22
  func (p *Payload) SetDf(df dataframe.DataFrame) {
 
8
  )
9
 
10
  type Payload struct {
11
+ CSVData string `json:"csv_data"`
12
+ Features []string `json:"features"`
13
+ Target string `json:"target"`
14
+ Algorithm string `json:"algorithm"`
 
 
 
15
 
16
+ Args map[string]interface{} `json:"args"`
17
+ Df dataframe.DataFrame
18
  }
19
 
20
  func (p *Payload) SetDf(df dataframe.DataFrame) {
server.go CHANGED
@@ -1,9 +1,10 @@
1
  package main
2
 
3
  import (
4
- "fmt"
5
  "strings"
6
 
 
 
7
  "github.com/Jensen-holm/ml-from-scratch/request"
8
  "github.com/go-gota/gota/dataframe"
9
  "github.com/gofiber/fiber/v2"
@@ -12,6 +13,10 @@ import (
12
  func main() {
13
  app := fiber.New()
14
 
 
 
 
 
15
  app.Post("/", func(c *fiber.Ctx) error {
16
  r := new(request.Payload)
17
 
@@ -25,7 +30,8 @@ func main() {
25
  df := dataframe.ReadCSV(strings.NewReader(r.CSVData))
26
  r.SetDf(df)
27
 
28
- fmt.Println(r.Df)
 
29
 
30
  return c.SendString("No error")
31
  })
 
1
  package main
2
 
3
  import (
 
4
  "strings"
5
 
6
+ "github.com/Jensen-holm/ml-from-scratch/alg"
7
+ "github.com/Jensen-holm/ml-from-scratch/nn"
8
  "github.com/Jensen-holm/ml-from-scratch/request"
9
  "github.com/go-gota/gota/dataframe"
10
  "github.com/gofiber/fiber/v2"
 
13
  func main() {
14
  app := fiber.New()
15
 
16
+ algMap := map[string]alg.Alg{
17
+ "neural_network": nn.NN,
18
+ }
19
+
20
  app.Post("/", func(c *fiber.Ctx) error {
21
  r := new(request.Payload)
22
 
 
30
  df := dataframe.ReadCSV(strings.NewReader(r.CSVData))
31
  r.SetDf(df)
32
 
33
+ a := algMap[r.Algorithm]
34
+ a.New(r)
35
 
36
  return c.SendString("No error")
37
  })