Spaces:
Sleeping
Sleeping
File size: 2,272 Bytes
fc9d5c1 602bdb4 fc9d5c1 9bd34af d54f001 fc9d5c1 9eca625 21cc39d 0ad6252 602bdb4 fc9d5c1 602bdb4 fc9d5c1 49c8dfd 602bdb4 fc9d5c1 59a1ae2 fc9d5c1 9bd34af 9eca625 d54f001 fc9d5c1 49c8dfd 602bdb4 49c8dfd fc9d5c1 21cc39d 49c8dfd 3faa322 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import gradio as gr
import pandas as pd
from neuralprophet import NeuralProphet
import warnings
import torch.optim as optim
from torch.optim.lr_scheduler import OneCycleLR
warnings.filterwarnings("ignore", category=UserWarning)

url = "VN Index Historical Data.csv"


def _load_price_history(path):
    """Load a price-history CSV into the ``ds``/``y`` frame NeuralProphet expects.

    Only the ``Date`` and ``Price`` columns are kept and renamed to ``ds`` and
    ``y``. Dates are parsed and rows are sorted oldest-first, so that the
    forward fill below only ever propagates *past* values. Rows that are still
    missing after the fill are dropped.

    Args:
        path: Path or URL accepted by ``pandas.read_csv``.

    Returns:
        A new DataFrame with columns ``ds`` (datetime64) and ``y`` (float).
    """
    raw = pd.read_csv(path)
    frame = raw[["Date", "Price"]].rename(columns={"Date": "ds", "Price": "y"})
    frame["ds"] = pd.to_datetime(frame["ds"])
    # NOTE(review): exported market-data CSVs often store prices as text with
    # thousands separators (e.g. "1,129.93"); strip them so y is numeric.
    # A no-op when the column is already numeric.
    if not pd.api.types.is_numeric_dtype(frame["y"]):
        frame["y"] = pd.to_numeric(
            frame["y"].astype(str).str.replace(",", "", regex=False),
            errors="coerce",
        )
    # Sort chronologically before filling. Such exports are often newest-first,
    # in which case an unsorted ffill would copy future prices backwards.
    frame = frame.sort_values("ds").reset_index(drop=True)
    # DataFrame.ffill() replaces fillna(method="ffill"), which was deprecated
    # in pandas 2.1 and removed in pandas 3.0.
    frame = frame.ffill()
    frame = frame.dropna()
    return frame


df = _load_price_history(url)
class CustomNeuralProphet(NeuralProphet):
    """NeuralProphet subclass that adds an ``optimizer`` attribute set to None.

    Every keyword argument is forwarded unchanged to ``NeuralProphet``.
    """

    def __init__(self, **model_kwargs):
        # Let the parent finish its own configuration before adding the slot.
        super(CustomNeuralProphet, self).__init__(**model_kwargs)
        self.optimizer = None
# Build the 30-step-ahead autoregressive model (12 lags) on the price series.
m = CustomNeuralProphet(
    n_forecasts=30,
    n_lags=12,
    changepoints_range=1,
    num_hidden_layers=3,
    yearly_seasonality=True,
    n_changepoints=150,
    trend_reg_threshold=False,
    d_hidden=3,
    global_normalization=True,
    seasonality_reg=1,
    unknown_data_normalization=True,
    seasonality_mode="multiplicative",
    drop_missing=True,
    learning_rate=0.03,
)
# The optimizer and learning-rate schedule are created by NeuralProphet itself
# inside fit(), driven by the constructor settings above (e.g. learning_rate).
# An earlier version built a separate Adam optimizer + OneCycleLR *after*
# fit() and attached it to m.trainer. By then training had finished and
# nothing ever stepped that scheduler, so it had no effect and was removed.
# To customise optimisation, configure it via the constructor before fit();
# check the options supported by the installed NeuralProphet version.
# NOTE(review): freq='D' treats the data as calendar-daily; trading data skips
# weekends/holidays, so confirm this is the intended frequency.
m.fit(df, freq='D')
future = m.make_future_dataframe(df, periods=30, n_historic_predictions=True)
forecast = m.predict(future)
_FORECAST_PLOT_PATH = "forecast_plot.png"
_forecast_plot_rendered = False


def predict_vn_index(option=None):
    """Return the saved forecast plot path and the investment disclaimer.

    ``forecast`` is computed once at import time, so the plot is identical for
    every request. It is rendered and saved on the first call only, and again
    if the file has since disappeared. Later calls reuse the saved image
    instead of creating (and never closing) a new figure each time.

    Args:
        option: Value from the UI dropdown; currently unused.

    Returns:
        A ``(path, disclaimer)`` tuple. ``path`` is the saved PNG and
        ``disclaimer`` is a Vietnamese notice. In English: "Customers should
        treat this only as a reference; the company accepts no responsibility
        for your investment outcomes."
    """
    global _forecast_plot_rendered
    import os

    if not (_forecast_plot_rendered and os.path.isfile(_FORECAST_PLOT_PATH)):
        # NOTE(review): assumes m.plot returns a matplotlib Figure (it is used
        # via .savefig); a plotly backend would need a different save call.
        fig = m.plot(forecast)
        fig.savefig(_FORECAST_PLOT_PATH)
        _forecast_plot_rendered = True
    disclaimer = "Quý khách chỉ xem đây là tham khảo, công ty không chịu bất cứ trách nhiệm nào về tình trạng đầu tư của quý khách."
    return _FORECAST_PLOT_PATH, disclaimer
if __name__ == "__main__":
    # Top-level components replace the gr.inputs / gr.outputs namespaces (and
    # the `default=` keyword), deprecated in Gradio 3.x and removed in 4.0.
    # The forms below work on both 3.x and 4.x.
    dropdown = gr.Dropdown(choices=["VNIndex"], value="VNIndex", label="Choose an option")
    # predict_vn_index returns a file path, so the image output takes a filepath.
    image_output = gr.Image(type="filepath", label="Forecast Plot")
    disclaimer_output = gr.Textbox(label="Disclaimer")
    interface = gr.Interface(
        fn=predict_vn_index,
        inputs=dropdown,
        outputs=[image_output, disclaimer_output],
        title="Dự báo VN Index 30 ngày tới",
    )
    interface.launch()
|