TuanScientist commited on
Commit
9eca625
·
1 Parent(s): f37e124

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -5
app.py CHANGED
@@ -3,6 +3,8 @@ import pandas as pd
3
  from neuralprophet import NeuralProphet
4
  import io
5
  import warnings
 
 
6
 
7
  warnings.filterwarnings("ignore", category=UserWarning)
8
 
@@ -13,13 +15,14 @@ df = df.rename(columns={"Date": "ds", "Price": "y"})
13
  df.fillna(method='ffill', inplace=True)
14
  df.dropna(inplace=True)
15
 
 
16
  class CustomNeuralProphet(NeuralProphet):
17
  def lr_scheduler_step(self, epoch, batch_idx, optimizer):
18
- # Get the OneCycleLR scheduler
19
- scheduler = self.optimizers[0].scheduler
20
-
21
- # Call the `step` method on the scheduler
22
- scheduler.step()
23
 
24
  m = CustomNeuralProphet(
25
  n_forecasts=30,
@@ -38,6 +41,14 @@ m = CustomNeuralProphet(
38
  learning_rate=0.03,
39
  )
40
 
 
 
 
 
 
 
 
 
41
  m.fit(df, freq='D')
42
 
43
  future = m.make_future_dataframe(df, periods=30, n_historic_predictions=True)
 
3
  from neuralprophet import NeuralProphet
4
  import io
5
  import warnings
6
+ from torch.optim.lr_scheduler import OneCycleLR
7
+ from torch.optim.optimizer import Optimizer
8
 
9
  warnings.filterwarnings("ignore", category=UserWarning)
10
 
 
15
  df.fillna(method='ffill', inplace=True)
16
  df.dropna(inplace=True)
17
 
18
+
19
  class CustomNeuralProphet(NeuralProphet):
20
  def lr_scheduler_step(self, epoch, batch_idx, optimizer):
21
+ # Custom logic for OneCycleLR scheduler step
22
+ for param_group in optimizer.param_groups:
23
+ if "lr_scheduler" in param_group:
24
+ lr_scheduler = param_group["lr_scheduler"]
25
+ lr_scheduler.step()
26
 
27
  m = CustomNeuralProphet(
28
  n_forecasts=30,
 
41
  learning_rate=0.03,
42
  )
43
 
44
+ # Set the custom LR scheduler
45
+ optimizer = m.trainer.optimizers[0]
46
+ optimizer.lr_scheduler = {
47
+ "scheduler": OneCycleLR(optimizer, max_lr=0.1, steps_per_epoch=100, epochs=10),
48
+ "interval": "step",
49
+ "frequency": 1,
50
+ }
51
+
52
  m.fit(df, freq='D')
53
 
54
  future = m.make_future_dataframe(df, periods=30, n_historic_predictions=True)