TuanScientist commited on
Commit
d54f001
·
1 Parent(s): 9bd34af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -16
app.py CHANGED
@@ -3,7 +3,7 @@ import pandas as pd
3
  from neuralprophet import NeuralProphet
4
  import warnings
5
  import torch.optim as optim
6
- from torch.optim.lr_scheduler import LambdaLR
7
 
8
  warnings.filterwarnings("ignore", category=UserWarning)
9
 
@@ -20,11 +20,6 @@ class CustomNeuralProphet(NeuralProphet):
20
  super().__init__(**kwargs)
21
  self.optimizer = None
22
 
23
- def lr_scheduler_step(self, epoch, batch_idx, optimizer):
24
- # Custom logic for LR scheduler step
25
- for lr_scheduler in optimizer.param_groups[0]['lr_scheduler']:
26
- lr_scheduler.step()
27
-
28
  m = CustomNeuralProphet(
29
  n_forecasts=30,
30
  n_lags=12,
@@ -46,8 +41,15 @@ m = CustomNeuralProphet(
46
  m.fit(df, freq='D') # Fit the model first before accessing the optimizer
47
  m.optimizer = optim.Adam(m.model.parameters(), lr=0.03) # Example optimizer, adjust as needed
48
 
49
- lr_scheduler = LambdaLR(m.optimizer, lambda epoch: 0.95 ** epoch) # Example LR scheduler, adjust as needed
50
- m.optimizer.param_groups[0]['lr_scheduler'] = [lr_scheduler]
 
 
 
 
 
 
 
51
 
52
  future = m.make_future_dataframe(df, periods=30, n_historic_predictions=True)
53
  forecast = m.predict(future)
@@ -67,11 +69,3 @@ if __name__ == "__main__":
67
  disclaimer_output = gr.outputs.Textbox(label="Disclaimer")
68
  interface = gr.Interface(fn=predict_vn_index, inputs=dropdown, outputs=[image_output, disclaimer_output], title="Dự báo VN Index 30 ngày tới")
69
  interface.launch()
70
-
71
-
72
-
73
-
74
-
75
-
76
-
77
-
 
3
  from neuralprophet import NeuralProphet
4
  import warnings
5
  import torch.optim as optim
6
+ from torch.optim.lr_scheduler import OneCycleLR
7
 
8
  warnings.filterwarnings("ignore", category=UserWarning)
9
 
 
20
  super().__init__(**kwargs)
21
  self.optimizer = None
22
 
 
 
 
 
 
23
  m = CustomNeuralProphet(
24
  n_forecasts=30,
25
  n_lags=12,
 
41
  m.fit(df, freq='D') # Fit the model first before accessing the optimizer
42
  m.optimizer = optim.Adam(m.model.parameters(), lr=0.03) # Example optimizer, adjust as needed
43
 
44
+ lr_scheduler = OneCycleLR(
45
+ m.optimizer,
46
+ max_lr=0.1,
47
+ total_steps=100,
48
+ pct_start=0.3,
49
+ anneal_strategy='cos',
50
+ ) # Example LR scheduler, adjust as needed
51
+
52
+ m.trainer.lr_schedulers = [lr_scheduler] # Set the LR scheduler to the trainer
53
 
54
  future = m.make_future_dataframe(df, periods=30, n_historic_predictions=True)
55
  forecast = m.predict(future)
 
69
  disclaimer_output = gr.outputs.Textbox(label="Disclaimer")
70
  interface = gr.Interface(fn=predict_vn_index, inputs=dropdown, outputs=[image_output, disclaimer_output], title="Dự báo VN Index 30 ngày tới")
71
  interface.launch()