Spaces:
Sleeping
Sleeping
Commit
·
eec3b88
1
Parent(s):
a4aa491
significantly decreased the complexity of the API by deciding to make
Browse filesseperate endpoints for each kind of model moving forward, and also
skipping the requet struct we had made and going straight to the
algorithm.
- alg/{alg.go → interface.go} +0 -0
- example/main.go +15 -18
- nn/main.go +18 -54
- request/request.go +0 -29
- server.go +5 -20
alg/{alg.go → interface.go}
RENAMED
File without changes
|
example/main.go
CHANGED
@@ -9,13 +9,11 @@ import (
|
|
9 |
)
|
10 |
|
11 |
type RequestPayload struct {
|
12 |
-
CSVData
|
13 |
-
Features
|
14 |
-
Target
|
15 |
-
|
16 |
-
|
17 |
-
LearningRate float64 `json:"learning_rate"`
|
18 |
-
ActivationFunc string `json:"activation_func"`
|
19 |
}
|
20 |
|
21 |
func main() {
|
@@ -30,19 +28,18 @@ func main() {
|
|
30 |
csvString := string(csvBytes)
|
31 |
features := []string{"petal length", "sepal length", "sepal width", "petal width"}
|
32 |
target := "species"
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
|
|
|
|
37 |
|
38 |
payload := RequestPayload{
|
39 |
-
CSVData:
|
40 |
-
Features:
|
41 |
-
Target:
|
42 |
-
|
43 |
-
HiddenSize: hiddenSize,
|
44 |
-
LearningRate: learningRate,
|
45 |
-
ActivationFunc: activationFunc,
|
46 |
}
|
47 |
|
48 |
jsonPayload, err := json.Marshal(payload)
|
|
|
9 |
)
|
10 |
|
11 |
type RequestPayload struct {
|
12 |
+
CSVData string `json:"csv_data"`
|
13 |
+
Features []string `json:"features"`
|
14 |
+
Target string `json:"target"`
|
15 |
+
Algorithm string `json:"algorithm"`
|
16 |
+
Args map[string]interface{}
|
|
|
|
|
17 |
}
|
18 |
|
19 |
func main() {
|
|
|
28 |
csvString := string(csvBytes)
|
29 |
features := []string{"petal length", "sepal length", "sepal width", "petal width"}
|
30 |
target := "species"
|
31 |
+
args := map[string]interface{}{
|
32 |
+
"epochs": 100,
|
33 |
+
"hidden_size": 8,
|
34 |
+
"learning_rate": 0.1,
|
35 |
+
"activation": "tanh",
|
36 |
+
}
|
37 |
|
38 |
payload := RequestPayload{
|
39 |
+
CSVData: csvString,
|
40 |
+
Features: features,
|
41 |
+
Target: target,
|
42 |
+
Args: args,
|
|
|
|
|
|
|
43 |
}
|
44 |
|
45 |
jsonPayload, err := json.Marshal(payload)
|
nn/main.go
CHANGED
@@ -2,65 +2,29 @@ package nn
|
|
2 |
|
3 |
import (
|
4 |
"fmt"
|
|
|
5 |
|
6 |
-
"github.com/
|
7 |
-
"github.com/
|
8 |
)
|
9 |
|
10 |
type NN struct {
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
}
|
14 |
|
15 |
-
func
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
"
|
20 |
-
"activation",
|
21 |
-
"hidden_size",
|
22 |
-
"learning_rate",
|
23 |
}
|
24 |
-
|
25 |
-
|
26 |
-
i, isIn := rp.Args[param]
|
27 |
-
if !isIn {
|
28 |
-
return nil, fmt.Errorf("user must specify %s", param)
|
29 |
-
}
|
30 |
-
|
31 |
-
switch param {
|
32 |
-
case "epochs":
|
33 |
-
if val, ok := i.(int); ok {
|
34 |
-
nnArgs[param] = val
|
35 |
-
} else {
|
36 |
-
return nil, fmt.Errorf("expected %s to be an int", param)
|
37 |
-
}
|
38 |
-
case "activation":
|
39 |
-
if val, ok := i.(string); ok {
|
40 |
-
nnArgs[param] = ActivationMap[val]
|
41 |
-
} else {
|
42 |
-
return nil, fmt.Errorf("expected %s to be a string", param)
|
43 |
-
}
|
44 |
-
case "hidden_size":
|
45 |
-
if val, ok := i.(int); ok {
|
46 |
-
nnArgs[param] = val
|
47 |
-
} else {
|
48 |
-
return nil, fmt.Errorf("expected %s to be an int", param)
|
49 |
-
}
|
50 |
-
case "learning_rate":
|
51 |
-
if val, ok := i.(float64); ok {
|
52 |
-
nnArgs[param] = val
|
53 |
-
} else {
|
54 |
-
return nil, fmt.Errorf("expected %s to be a float64", param)
|
55 |
-
}
|
56 |
-
default:
|
57 |
-
return nil, fmt.Errorf("unsupported parameter: %s", param)
|
58 |
-
}
|
59 |
-
}
|
60 |
-
|
61 |
-
args := NewArgs(nnArgs)
|
62 |
-
return &NN{
|
63 |
-
args: args,
|
64 |
-
}, nil
|
65 |
-
|
66 |
}
|
|
|
2 |
|
3 |
import (
|
4 |
"fmt"
|
5 |
+
"strings"
|
6 |
|
7 |
+
"github.com/go-gota/gota/dataframe"
|
8 |
+
"github.com/gofiber/fiber/v2"
|
9 |
)
|
10 |
|
11 |
type NN struct {
|
12 |
+
CSVData string `json:"csv_data"`
|
13 |
+
Features []string `json:"features"`
|
14 |
+
Target string `json:"target"`
|
15 |
+
Epochs int `json:"epochs"`
|
16 |
+
HiddenSize int `json:"hidden_size"`
|
17 |
+
LearningRate float64 `json:"learning_rate"`
|
18 |
+
ActivationFunc string `json:"activation"`
|
19 |
+
Df dataframe.DataFrame
|
20 |
}
|
21 |
|
22 |
+
func NewNN(c *fiber.Ctx) (*NN, error) {
|
23 |
+
newNN := new(NN)
|
24 |
+
err := c.BodyParser(newNN)
|
25 |
+
if err != nil {
|
26 |
+
return nil, fmt.Errorf("invalid JSON data: %v", err)
|
|
|
|
|
|
|
27 |
}
|
28 |
+
newNN.Df = dataframe.ReadCSV(strings.NewReader(newNN.CSVData))
|
29 |
+
return newNN, nil
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
}
|
request/request.go
DELETED
@@ -1,29 +0,0 @@
|
|
1 |
-
package request
|
2 |
-
|
3 |
-
import (
|
4 |
-
"fmt"
|
5 |
-
|
6 |
-
"github.com/go-gota/gota/dataframe"
|
7 |
-
"github.com/gofiber/fiber/v2"
|
8 |
-
)
|
9 |
-
|
10 |
-
type Payload struct {
|
11 |
-
CSVData string `json:"csv_data"`
|
12 |
-
Features []string `json:"features"`
|
13 |
-
Target string `json:"target"`
|
14 |
-
Algorithm string `json:"algorithm"`
|
15 |
-
|
16 |
-
Args map[string]interface{} `json:"args"`
|
17 |
-
Df dataframe.DataFrame
|
18 |
-
}
|
19 |
-
|
20 |
-
func (p *Payload) SetDf(df dataframe.DataFrame) {
|
21 |
-
p.Df = df
|
22 |
-
}
|
23 |
-
|
24 |
-
func NewPayload(dest *Payload, c *fiber.Ctx) error {
|
25 |
-
if err := c.BodyParser(dest); err != nil {
|
26 |
-
return fmt.Errorf("invalid JSON data")
|
27 |
-
}
|
28 |
-
return nil
|
29 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
server.go
CHANGED
@@ -1,41 +1,26 @@
|
|
1 |
package main
|
2 |
|
3 |
import (
|
4 |
-
"strings"
|
5 |
-
|
6 |
-
"github.com/Jensen-holm/ml-from-scratch/alg"
|
7 |
"github.com/Jensen-holm/ml-from-scratch/nn"
|
8 |
-
"github.com/Jensen-holm/ml-from-scratch/request"
|
9 |
-
"github.com/go-gota/gota/dataframe"
|
10 |
"github.com/gofiber/fiber/v2"
|
11 |
)
|
12 |
|
13 |
func main() {
|
14 |
app := fiber.New()
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
app.Post("/", func(c *fiber.Ctx) error {
|
21 |
-
r := new(request.Payload)
|
22 |
|
23 |
-
err :=
|
24 |
if err != nil {
|
25 |
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
26 |
-
"error":
|
27 |
})
|
28 |
}
|
29 |
|
30 |
-
df := dataframe.ReadCSV(strings.NewReader(r.CSVData))
|
31 |
-
r.SetDf(df)
|
32 |
-
|
33 |
-
a := algMap[r.Algorithm]
|
34 |
-
a.New(r)
|
35 |
-
|
36 |
return c.SendString("No error")
|
37 |
})
|
38 |
|
39 |
app.Listen(":3000")
|
40 |
-
|
41 |
}
|
|
|
1 |
package main
|
2 |
|
3 |
import (
|
|
|
|
|
|
|
4 |
"github.com/Jensen-holm/ml-from-scratch/nn"
|
|
|
|
|
5 |
"github.com/gofiber/fiber/v2"
|
6 |
)
|
7 |
|
8 |
func main() {
|
9 |
app := fiber.New()
|
10 |
|
11 |
+
// eventually we might want to add a key to this endpoint
|
12 |
+
// that we will be able to validate.
|
13 |
+
app.Post("/neural-network", func(c *fiber.Ctx) error {
|
|
|
|
|
|
|
14 |
|
15 |
+
nn, err := nn.NewNN(c)
|
16 |
if err != nil {
|
17 |
return c.Status(fiber.StatusBadRequest).JSON(fiber.Map{
|
18 |
+
"error": err,
|
19 |
})
|
20 |
}
|
21 |
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
return c.SendString("No error")
|
23 |
})
|
24 |
|
25 |
app.Listen(":3000")
|
|
|
26 |
}
|