TuanScientist commited on
Commit
602bdb4
·
1 Parent(s): 44862d1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -13
app.py CHANGED
@@ -1,14 +1,12 @@
1
  import gradio as gr
2
  import pandas as pd
3
- from neuralprophet import NeuralProphet, set_log_level
4
  import io
5
  import warnings
6
- import pytorch_lightning
7
 
8
  warnings.filterwarnings("ignore", category=UserWarning)
9
 
10
- set_log_level("ERROR")
11
-
12
  url = "VN Index Historical Data.csv"
13
  df = pd.read_csv(url)
14
  df = df[["Date", "Price"]]
@@ -16,24 +14,30 @@ df = df.rename(columns={"Date": "ds", "Price": "y"})
16
  df.fillna(method='ffill', inplace=True)
17
  df.dropna(inplace=True)
18
 
19
- m = NeuralProphet(
 
 
 
 
 
 
20
  n_forecasts=30,
21
  n_lags=12,
22
- changepoints_range=5,
23
- num_hidden_layers=6,
24
  yearly_seasonality=True,
25
  n_changepoints=150,
26
  trend_reg_threshold=False, # Disable trend regularization threshold
27
- d_hidden=9,
28
  global_normalization=True,
29
  seasonality_reg=1,
30
  unknown_data_normalization=True,
31
  seasonality_mode="multiplicative",
32
  drop_missing=True,
33
- learning_rate=0.1
34
  )
35
 
36
- m.fit(df, freq='D')
37
 
38
  future = m.make_future_dataframe(df, periods=30, n_historic_predictions=True)
39
  forecast = m.predict(future)
@@ -43,13 +47,17 @@ def predict_vn_index(option=None):
43
  fig = m.plot(forecast)
44
  path = "forecast_plot.png"
45
  fig.savefig(path)
46
- return path
 
47
 
48
 
49
  if __name__ == "__main__":
50
  dropdown = gr.inputs.Dropdown(["VNIndex"], label="Choose an option", default="VNIndex")
51
- interface = gr.Interface(fn=predict_vn_index, inputs=dropdown, outputs="image", title="Dự báo VN Index 30 ngày tới")
52
- interface.launch()
 
 
 
53
 
54
 
55
 
 
1
  import gradio as gr
2
  import pandas as pd
3
+ from neuralprophet import NeuralProphet
4
  import io
5
  import warnings
6
+ import torch
7
 
8
  warnings.filterwarnings("ignore", category=UserWarning)
9
 
 
 
10
  url = "VN Index Historical Data.csv"
11
  df = pd.read_csv(url)
12
  df = df[["Date", "Price"]]
 
14
  df.fillna(method='ffill', inplace=True)
15
  df.dropna(inplace=True)
16
 
17
+ class CustomNeuralProphet(NeuralProphet):
18
+ def lr_scheduler_step(self, epoch: int = None) -> None:
19
+ # Override the lr_scheduler_step method to avoid the MisconfigurationException
20
+ if self.lr_scheduler is not None and isinstance(self.lr_scheduler, torch.optim.lr_scheduler.OneCycleLR):
21
+ self.lr_scheduler.step()
22
+
23
+ m = CustomNeuralProphet(
24
  n_forecasts=30,
25
  n_lags=12,
26
+ changepoints_range=1,
27
+ num_hidden_layers=3,
28
  yearly_seasonality=True,
29
  n_changepoints=150,
30
  trend_reg_threshold=False, # Disable trend regularization threshold
31
+ d_hidden=3,
32
  global_normalization=True,
33
  seasonality_reg=1,
34
  unknown_data_normalization=True,
35
  seasonality_mode="multiplicative",
36
  drop_missing=True,
37
+ learning_rate=0.1,
38
  )
39
 
40
+ m.fit(df, freq='D', epochs=10, validate_each_epoch=True, valid_p=0.2) # Specify number of epochs and validation parameters
41
 
42
  future = m.make_future_dataframe(df, periods=30, n_historic_predictions=True)
43
  forecast = m.predict(future)
 
47
  fig = m.plot(forecast)
48
  path = "forecast_plot.png"
49
  fig.savefig(path)
50
+ disclaimer = "Quý khách chỉ xem đây là tham khảo, công ty không chịu bất cứ trách nhiệm nào về tình trạng đầu tư của quý khách."
51
+ return path, disclaimer
52
 
53
 
54
  if __name__ == "__main__":
55
  dropdown = gr.inputs.Dropdown(["VNIndex"], label="Choose an option", default="VNIndex")
56
+ image_output = gr.outputs.Image(type="file", label="Forecast Plot")
57
+ disclaimer_output = gr.outputs.Textbox(label="Disclaimer")
58
+ interface = gr.Interface(fn=predict_vn_index, inputs=dropdown, outputs=[image_output, disclaimer_output], title="Dự báo VN Index 30 ngày tới")
59
+ interface.launch(share=True)
60
+
61
 
62
 
63