Spaces:
Sleeping
Sleeping
Commit
·
0ad6252
1
Parent(s):
22ba16b
Update app.py
Browse files
app.py
CHANGED
@@ -1,7 +1,6 @@
|
|
1 |
import gradio as gr
|
2 |
import pandas as pd
|
3 |
from neuralprophet import NeuralProphet
|
4 |
-
import io
|
5 |
import warnings
|
6 |
from torch.optim.lr_scheduler import OneCycleLR
|
7 |
|
@@ -16,6 +15,10 @@ df.dropna(inplace=True)
|
|
16 |
|
17 |
|
18 |
class CustomNeuralProphet(NeuralProphet):
|
|
|
|
|
|
|
|
|
19 |
def lr_scheduler_step(self, epoch, batch_idx, optimizer):
|
20 |
# Custom logic for OneCycleLR scheduler step
|
21 |
for param_group in optimizer.param_groups:
|
@@ -23,8 +26,6 @@ class CustomNeuralProphet(NeuralProphet):
|
|
23 |
lr_scheduler = param_group["lr_scheduler"]
|
24 |
lr_scheduler.step()
|
25 |
|
26 |
-
optimizer = None
|
27 |
-
|
28 |
|
29 |
m = CustomNeuralProphet(
|
30 |
n_forecasts=30,
|
@@ -44,9 +45,10 @@ m = CustomNeuralProphet(
|
|
44 |
)
|
45 |
|
46 |
# Set the custom LR scheduler
|
47 |
-
optimizer = m.optimizer
|
48 |
-
|
49 |
-
|
|
|
50 |
"interval": "step",
|
51 |
"frequency": 1,
|
52 |
}
|
@@ -77,3 +79,4 @@ if __name__ == "__main__":
|
|
77 |
|
78 |
|
79 |
|
|
|
|
1 |
import gradio as gr
|
2 |
import pandas as pd
|
3 |
from neuralprophet import NeuralProphet
|
|
|
4 |
import warnings
|
5 |
from torch.optim.lr_scheduler import OneCycleLR
|
6 |
|
|
|
15 |
|
16 |
|
17 |
class CustomNeuralProphet(NeuralProphet):
|
18 |
+
def __init__(self, **kwargs):
|
19 |
+
super().__init__(**kwargs)
|
20 |
+
self.optimizer = None
|
21 |
+
|
22 |
def lr_scheduler_step(self, epoch, batch_idx, optimizer):
|
23 |
# Custom logic for OneCycleLR scheduler step
|
24 |
for param_group in optimizer.param_groups:
|
|
|
26 |
lr_scheduler = param_group["lr_scheduler"]
|
27 |
lr_scheduler.step()
|
28 |
|
|
|
|
|
29 |
|
30 |
m = CustomNeuralProphet(
|
31 |
n_forecasts=30,
|
|
|
45 |
)
|
46 |
|
47 |
# Set the custom LR scheduler
|
48 |
+
m.optimizer = m.model.optimizers[0] # Use the optimizer from the model
|
49 |
+
|
50 |
+
m.optimizer.lr_scheduler = {
|
51 |
+
"scheduler": OneCycleLR(m.optimizer),
|
52 |
"interval": "step",
|
53 |
"frequency": 1,
|
54 |
}
|
|
|
79 |
|
80 |
|
81 |
|
82 |
+
|