TuanScientist commited on
Commit
9bd34af
·
1 Parent(s): b673837

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -8
app.py CHANGED
@@ -2,7 +2,8 @@ import gradio as gr
2
  import pandas as pd
3
  from neuralprophet import NeuralProphet
4
  import warnings
5
- from torch.optim.lr_scheduler import OneCycleLR
 
6
 
7
  warnings.filterwarnings("ignore", category=UserWarning)
8
 
@@ -20,12 +21,9 @@ class CustomNeuralProphet(NeuralProphet):
20
  self.optimizer = None
21
 
22
  def lr_scheduler_step(self, epoch, batch_idx, optimizer):
23
- # Custom logic for OneCycleLR scheduler step
24
- for param_group in optimizer.param_groups:
25
- if "lr_scheduler" in param_group:
26
- lr_scheduler = param_group["lr_scheduler"]
27
- lr_scheduler.step()
28
-
29
 
30
  m = CustomNeuralProphet(
31
  n_forecasts=30,
@@ -44,8 +42,12 @@ m = CustomNeuralProphet(
44
  learning_rate=0.03,
45
  )
46
 
 
 
 
47
 
48
- m.fit(df, freq='D')
 
49
 
50
  future = m.make_future_dataframe(df, periods=30, n_historic_predictions=True)
51
  forecast = m.predict(future)
@@ -72,3 +74,4 @@ if __name__ == "__main__":
72
 
73
 
74
 
 
 
2
  import pandas as pd
3
  from neuralprophet import NeuralProphet
4
  import warnings
5
+ import torch.optim as optim
6
+ from torch.optim.lr_scheduler import LambdaLR
7
 
8
  warnings.filterwarnings("ignore", category=UserWarning)
9
 
 
21
  self.optimizer = None
22
 
23
  def lr_scheduler_step(self, epoch, batch_idx, optimizer):
24
+ # Custom logic for LR scheduler step
25
+ for lr_scheduler in optimizer.param_groups[0]['lr_scheduler']:
26
+ lr_scheduler.step()
 
 
 
27
 
28
  m = CustomNeuralProphet(
29
  n_forecasts=30,
 
42
  learning_rate=0.03,
43
  )
44
 
45
+ # Set the custom LR scheduler
46
+ m.fit(df, freq='D') # Fit the model first before accessing the optimizer
47
+ m.optimizer = optim.Adam(m.model.parameters(), lr=0.03) # Example optimizer, adjust as needed
48
 
49
+ lr_scheduler = LambdaLR(m.optimizer, lambda epoch: 0.95 ** epoch) # Example LR scheduler, adjust as needed
50
+ m.optimizer.param_groups[0]['lr_scheduler'] = [lr_scheduler]
51
 
52
  future = m.make_future_dataframe(df, periods=30, n_historic_predictions=True)
53
  forecast = m.predict(future)
 
74
 
75
 
76
 
77
+