TuanScientist committed
Commit 0ad6252 · 1 Parent(s): 22ba16b

Update app.py

Files changed (1):
  1. app.py +9 -6
app.py CHANGED
@@ -1,7 +1,6 @@
 import gradio as gr
 import pandas as pd
 from neuralprophet import NeuralProphet
-import io
 import warnings
 from torch.optim.lr_scheduler import OneCycleLR
 
@@ -16,6 +15,10 @@ df.dropna(inplace=True)
 
 
 class CustomNeuralProphet(NeuralProphet):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.optimizer = None
+
     def lr_scheduler_step(self, epoch, batch_idx, optimizer):
         # Custom logic for OneCycleLR scheduler step
         for param_group in optimizer.param_groups:
@@ -23,8 +26,6 @@ class CustomNeuralProphet(NeuralProphet):
             lr_scheduler = param_group["lr_scheduler"]
             lr_scheduler.step()
 
-    optimizer = None
-
 
 m = CustomNeuralProphet(
     n_forecasts=30,
@@ -44,9 +45,10 @@ m = CustomNeuralProphet(
 )
 
 # Set the custom LR scheduler
-optimizer = m.optimizer
-optimizer.lr_scheduler = {
-    "scheduler": OneCycleLR(optimizer),
+m.optimizer = m.model.optimizers[0]  # Use the optimizer from the model
+
+m.optimizer.lr_scheduler = {
+    "scheduler": OneCycleLR(m.optimizer),
     "interval": "step",
     "frequency": 1,
 }
@@ -77,3 +79,4 @@ if __name__ == "__main__":
 
 
 
+
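
For context, the effect of this patch is to attach a PyTorch OneCycleLR schedule to the model's optimizer and step it once per batch. Below is a minimal sketch of that wiring pattern in plain PyTorch, under stated assumptions: the Linear model, max_lr=0.01, and total_steps=100 are illustrative stand-ins, not values from this commit. Note also that the diff's own call, OneCycleLR(m.optimizer), would raise a TypeError at runtime, because PyTorch's OneCycleLR additionally requires max_lr and a step budget (total_steps, or epochs together with steps_per_epoch).

import torch
from torch.optim.lr_scheduler import OneCycleLR

# Illustrative stand-in; the commit instead pulls the optimizer from
# NeuralProphet via m.model.optimizers[0].
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# OneCycleLR needs the peak LR and the total number of scheduler steps
# up front; max_lr=0.01 and total_steps=100 are assumed values.
scheduler = OneCycleLR(optimizer, max_lr=0.01, total_steps=100)

for step in range(100):
    optimizer.zero_grad()
    loss = model(torch.randn(4, 10)).pow(2).mean()
    loss.backward()
    optimizer.step()
    scheduler.step()  # advance the schedule once per optimizer step

Stepping the scheduler once per optimizer step is what the diff's "interval": "step", "frequency": 1 settings express; that dictionary layout ({"scheduler", "interval", "frequency"}) mirrors the lr-scheduler configuration convention of PyTorch Lightning, which NeuralProphet builds on.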