Jensen-holm commited on
Commit
eec3b88
·
1 Parent(s): a4aa491

significantly decreased the complexity of the API by deciding to make

Browse files

seperate endpoints for each kind of model moving forward, and also
skipping the requet struct we had made and going straight to the
algorithm.

Files changed (5) hide show
  1. alg/{alg.go → interface.go} +0 -0
  2. example/main.go +15 -18
  3. nn/main.go +18 -54
  4. request/request.go +0 -29
  5. server.go +5 -20
alg/{alg.go → interface.go} RENAMED
File without changes
example/main.go CHANGED
@@ -9,13 +9,11 @@ import (
9
  )
10
 
11
  type RequestPayload struct {
12
- CSVData string `json:"csv_data"`
13
- Features []string `json:"features"`
14
- Target string `json:"target"`
15
- Epochs int `json:"epochs"`
16
- HiddenSize int `json:"hidden_size"`
17
- LearningRate float64 `json:"learning_rate"`
18
- ActivationFunc string `json:"activation_func"`
19
  }
20
 
21
  func main() {
@@ -30,19 +28,18 @@ func main() {
30
  csvString := string(csvBytes)
31
  features := []string{"petal length", "sepal length", "sepal width", "petal width"}
32
  target := "species"
33
- epochs := 100
34
- hiddenSize := 8
35
- learningRate := 0.1
36
- activationFunc := "tanh"
 
 
37
 
38
  payload := RequestPayload{
39
- CSVData: csvString,
40
- Features: features,
41
- Target: target,
42
- Epochs: epochs,
43
- HiddenSize: hiddenSize,
44
- LearningRate: learningRate,
45
- ActivationFunc: activationFunc,
46
  }
47
 
48
  jsonPayload, err := json.Marshal(payload)
 
9
  )
10
 
11
  type RequestPayload struct {
12
+ CSVData string `json:"csv_data"`
13
+ Features []string `json:"features"`
14
+ Target string `json:"target"`
15
+ Algorithm string `json:"algorithm"`
16
+ Args map[string]interface{}
 
 
17
  }
18
 
19
  func main() {
 
28
  csvString := string(csvBytes)
29
  features := []string{"petal length", "sepal length", "sepal width", "petal width"}
30
  target := "species"
31
+ args := map[string]interface{}{
32
+ "epochs": 100,
33
+ "hidden_size": 8,
34
+ "learning_rate": 0.1,
35
+ "activation": "tanh",
36
+ }
37
 
38
  payload := RequestPayload{
39
+ CSVData: csvString,
40
+ Features: features,
41
+ Target: target,
42
+ Args: args,
 
 
 
43
  }
44
 
45
  jsonPayload, err := json.Marshal(payload)
nn/main.go CHANGED
@@ -2,65 +2,29 @@ package nn
2
 
3
  import (
4
  "fmt"
 
5
 
6
- "github.com/Jensen-holm/ml-from-scratch/alg"
7
- "github.com/Jensen-holm/ml-from-scratch/request"
8
  )
9
 
10
  type NN struct {
11
- alg.Alg
12
- args *NNArgs
 
 
 
 
 
 
13
  }
14
 
15
- func New(rp *request.Payload) (alg.Alg, error) {
16
- // parse the args and make a new NN struct
17
- nnArgs := make(map[string]interface{})
18
- params := []string{
19
- "epochs",
20
- "activation",
21
- "hidden_size",
22
- "learning_rate",
23
  }
24
-
25
- for _, param := range params {
26
- i, isIn := rp.Args[param]
27
- if !isIn {
28
- return nil, fmt.Errorf("user must specify %s", param)
29
- }
30
-
31
- switch param {
32
- case "epochs":
33
- if val, ok := i.(int); ok {
34
- nnArgs[param] = val
35
- } else {
36
- return nil, fmt.Errorf("expected %s to be an int", param)
37
- }
38
- case "activation":
39
- if val, ok := i.(string); ok {
40
- nnArgs[param] = ActivationMap[val]
41
- } else {
42
- return nil, fmt.Errorf("expected %s to be a string", param)
43
- }
44
- case "hidden_size":
45
- if val, ok := i.(int); ok {
46
- nnArgs[param] = val
47
- } else {
48
- return nil, fmt.Errorf("expected %s to be an int", param)
49
- }
50
- case "learning_rate":
51
- if val, ok := i.(float64); ok {
52
- nnArgs[param] = val
53
- } else {
54
- return nil, fmt.Errorf("expected %s to be a float64", param)
55
- }
56
- default:
57
- return nil, fmt.Errorf("unsupported parameter: %s", param)
58
- }
59
- }
60
-
61
- args := NewArgs(nnArgs)
62
- return &NN{
63
- args: args,
64
- }, nil
65
-
66
  }
 
2
 
3
  import (
4
  "fmt"
5
+ "strings"
6
 
7
+ "github.com/go-gota/gota/dataframe"
8
+ "github.com/gofiber/fiber/v2"
9
  )
10
 
11
  type NN struct {
12
+ CSVData string `json:"csv_data"`
13
+ Features []string `json:"features"`
14
+ Target string `json:"target"`
15
+ Epochs int `json:"epochs"`
16
+ HiddenSize int `json:"hidden_size"`
17
+ LearningRate float64 `json:"learning_rate"`
18
+ ActivationFunc string `json:"activation"`
19
+ Df dataframe.DataFrame
20
  }
21
 
22
+ func NewNN(c *fiber.Ctx) (*NN, error) {
23
+ newNN := new(NN)
24
+ err := c.BodyParser(newNN)
25
+ if err != nil {
26
+ return nil, fmt.Errorf("invalid JSON data: %v", err)
 
 
 
27
  }
28
+ newNN.Df = dataframe.ReadCSV(strings.NewReader(newNN.CSVData))
29
+ return newNN, nil
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  }
request/request.go DELETED
@@ -1,29 +0,0 @@
1
- package request
2
-
3
- import (
4
- "fmt"
5
-
6
- "github.com/go-gota/gota/dataframe"
7
- "github.com/gofiber/fiber/v2"
8
- )
9
-
10
- type Payload struct {
11
- CSVData string `json:"csv_data"`
12
- Features []string `json:"features"`
13
- Target string `json:"target"`
14
- Algorithm string `json:"algorithm"`
15
-
16
- Args map[string]interface{} `json:"args"`
17
- Df dataframe.DataFrame
18
- }
19
-
20
- func (p *Payload) SetDf(df dataframe.DataFrame) {
21
- p.Df = df
22
- }
23
-
24
- func NewPayload(dest *Payload, c *fiber.Ctx) error {
25
- if err := c.BodyParser(dest); err != nil {
26
- return fmt.Errorf("invalid JSON data")
27
- }
28
- return nil
29
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
server.go CHANGED
@@ -1,41 +1,26 @@
1
  package main
2
 
3
  import (
4
- "strings"
5
-
6
- "github.com/Jensen-holm/ml-from-scratch/alg"
7
  "github.com/Jensen-holm/ml-from-scratch/nn"
8
- "github.com/Jensen-holm/ml-from-scratch/request"
9
- "github.com/go-gota/gota/dataframe"
10
  "github.com/gofiber/fiber/v2"
11
  )
12
 
13
  func main() {
14
  app := fiber.New()
15
 
16
- algMap := map[string]alg.Alg{
17
- "neural_network": nn.NN,
18
- }
19
-
20
- app.Post("/", func(c *fiber.Ctx) error {
21
- r := new(request.Payload)
22
 
23
- err := request.NewPayload(r, c)
24
  if err != nil {
25
  return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
26
- "error": "Invalid JSON data",
27
  })
28
  }
29
 
30
- df := dataframe.ReadCSV(strings.NewReader(r.CSVData))
31
- r.SetDf(df)
32
-
33
- a := algMap[r.Algorithm]
34
- a.New(r)
35
-
36
  return c.SendString("No error")
37
  })
38
 
39
  app.Listen(":3000")
40
-
41
  }
 
1
  package main
2
 
3
  import (
 
 
 
4
  "github.com/Jensen-holm/ml-from-scratch/nn"
 
 
5
  "github.com/gofiber/fiber/v2"
6
  )
7
 
8
  func main() {
9
  app := fiber.New()
10
 
11
+ // eventually we might want to add a key to this endpoint
12
+ // that we will be able to validate.
13
+ app.Post("/neural-network", func(c *fiber.Ctx) error {
 
 
 
14
 
15
+ nn, err := nn.NewNN(c)
16
  if err != nil {
17
  return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
18
+ "error": err,
19
  })
20
  }
21
 
 
 
 
 
 
 
22
  return c.SendString("No error")
23
  })
24
 
25
  app.Listen(":3000")
 
26
  }