TuanScientist commited on
Commit
49c8dfd
·
1 Parent(s): 59a1ae2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -23
app.py CHANGED
@@ -3,7 +3,7 @@ import pandas as pd
3
  from neuralprophet import NeuralProphet
4
  import io
5
  import warnings
6
- import torch
7
 
8
  warnings.filterwarnings("ignore", category=UserWarning)
9
 
@@ -15,10 +15,12 @@ df.fillna(method='ffill', inplace=True)
15
  df.dropna(inplace=True)
16
 
17
  class CustomNeuralProphet(NeuralProphet):
18
- def lr_scheduler_step(self, epoch: int = None) -> None:
19
- # Override the lr_scheduler_step method to avoid the MisconfigurationException
20
- if self.lr_scheduler is not None and isinstance(self.lr_scheduler, torch.optim.lr_scheduler.OneCycleLR):
21
- self.lr_scheduler.step()
 
 
22
 
23
  m = CustomNeuralProphet(
24
  n_forecasts=30,
@@ -27,7 +29,7 @@ m = CustomNeuralProphet(
27
  num_hidden_layers=3,
28
  yearly_seasonality=True,
29
  n_changepoints=150,
30
- trend_reg_threshold=False, # Disable trend regularization threshold
31
  d_hidden=3,
32
  global_normalization=True,
33
  seasonality_reg=1,
@@ -44,28 +46,19 @@ forecast = m.predict(future)
44
 
45
 
46
  def predict_vn_index(option=None):
47
- fig1 = m.plot(forecast)
48
- fig1_path = "forecast_plot1.png"
49
- fig1.savefig(fig1_path)
50
-
51
- # Add code to generate the second image (fig2)
52
- fig2 = m.plot_latest_forecast(forecast) # Replace this line with code to generate the second image
53
- fig2_path = "forecast_plot2.png"
54
- fig2.savefig(fig2_path)
55
  disclaimer = "Quý khách chỉ xem đây là tham khảo, công ty không chịu bất cứ trách nhiệm nào về tình trạng đầu tư của quý khách."
56
-
57
- return fig1_path, fig2_path, disclaimer
58
 
59
 
60
  if __name__ == "__main__":
61
  dropdown = gr.inputs.Dropdown(["VNIndex"], label="Choose an option", default="VNIndex")
62
- outputs = [
63
- gr.outputs.Image(type="filepath", label="First Image"),
64
- gr.outputs.Image(type="filepath", label="Second Image"),
65
- gr.outputs.Textbox(label="Disclaimer")
66
- ]
67
- interface = gr.Interface(fn=predict_vn_index, inputs=dropdown, outputs=outputs, title="Dự báo VN Index 30 ngày tới")
68
- interface.launch()
69
 
70
 
71
 
 
3
  from neuralprophet import NeuralProphet
4
  import io
5
  import warnings
6
+ import torch.optim.lr_scheduler as lr_scheduler
7
 
8
  warnings.filterwarnings("ignore", category=UserWarning)
9
 
 
15
  df.dropna(inplace=True)
16
 
17
  class CustomNeuralProphet(NeuralProphet):
18
+ def __init__(self, **kwargs):
19
+ super().__init__(**kwargs)
20
+ self.lr_scheduler = lr_scheduler.ReduceLROnPlateau(self.trainer.optimizers[0], mode='min', factor=0.5, patience=10, verbose=True)
21
+
22
+ def lr_scheduler_step(self, metrics):
23
+ self.lr_scheduler.step(metrics)
24
 
25
  m = CustomNeuralProphet(
26
  n_forecasts=30,
 
29
  num_hidden_layers=3,
30
  yearly_seasonality=True,
31
  n_changepoints=150,
32
+ trend_reg_threshold=False,
33
  d_hidden=3,
34
  global_normalization=True,
35
  seasonality_reg=1,
 
46
 
47
 
48
  def predict_vn_index(option=None):
49
+ fig = m.plot(forecast)
50
+ path = "forecast_plot.png"
51
+ fig.savefig(path)
 
 
 
 
 
52
  disclaimer = "Quý khách chỉ xem đây là tham khảo, công ty không chịu bất cứ trách nhiệm nào về tình trạng đầu tư của quý khách."
53
+ return path, disclaimer
 
54
 
55
 
56
  if __name__ == "__main__":
57
  dropdown = gr.inputs.Dropdown(["VNIndex"], label="Choose an option", default="VNIndex")
58
+ image_output = gr.outputs.Image(type="file", label="Forecast Plot")
59
+ disclaimer_output = gr.outputs.Textbox(label="Disclaimer")
60
+ interface = gr.Interface(fn=predict_vn_index, inputs=dropdown, outputs=[image_output, disclaimer_output], title="Dự báo VN Index 30 ngày tới")
61
+ interface.launch(share=True)
 
 
 
62
 
63
 
64