hectorduran commited on
Commit
e997929
·
verified ·
1 Parent(s): 33812a6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -0
app.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yfinance as yf
2
+ import pandas as pd
3
+ import numpy as np
4
+ from prophet import Prophet
5
+ import plotly.graph_objs as go
6
+ import plotly.express as px
7
+ import gradio as gr
8
+ from pmdarima import auto_arima
9
+
10
+ def forecast_stock(ticker, period, future_days, use_arima):
11
+ # Fetch data
12
+ data = yf.Ticker(ticker)
13
+ df = data.history(period=period)
14
+
15
+ if df.empty:
16
+ return "Could not retrieve data for the selected ticker."
17
+
18
+ df = df.reset_index()
19
+ df = df[['Date', 'Close']]
20
+ df.columns = ['ds', 'y']
21
+ df['ds'] = pd.to_datetime(df['ds']).dt.tz_localize(None)
22
+ df = df.dropna()
23
+
24
+ # Prophet forecast
25
+ model = Prophet()
26
+ model.fit(df)
27
+ future_dates = model.make_future_dataframe(periods=future_days)
28
+ forecast = model.predict(future_dates)
29
+
30
+ # Create Plotly figure for forecast
31
+ fig_forecast = go.Figure()
32
+ fig_forecast.add_trace(go.Scatter(x=df['ds'], y=df['y'], name='Historical'))
33
+ fig_forecast.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat'], name='Prophet Forecast'))
34
+ fig_forecast.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat_upper'], name='Prophet Upper Bound', line=dict(dash='dash')))
35
+ fig_forecast.add_trace(go.Scatter(x=forecast['ds'], y=forecast['yhat_lower'], name='Prophet Lower Bound', line=dict(dash='dash')))
36
+
37
+ if use_arima:
38
+ # ARIMA forecast with automatic order selection
39
+ model_arima = auto_arima(df['y'], seasonal=False, trace=True)
40
+ results_arima = model_arima.fit(df['y'])
41
+ arima_forecast = results_arima.predict(n_periods=future_days)
42
+ future_dates_arima = pd.date_range(start=df['ds'].iloc[-1] + pd.Timedelta(days=1), periods=future_days)
43
+
44
+ fig_forecast.add_trace(go.Scatter(x=future_dates_arima, y=arima_forecast, name='ARIMA Forecast'))
45
+
46
+ fig_forecast.update_layout(title=f'Stock Price Forecast for {ticker}', xaxis_title='Date', yaxis_title='Stock Price')
47
+
48
+ # Create Plotly figure for Prophet components
49
+ fig_components = px.line(forecast, x='ds', y=['trend', 'yearly', 'weekly'])
50
+ fig_components.update_layout(title='Prophet Forecast Components')
51
+
52
+ return fig_forecast, fig_components
53
+
54
+ # Define Gradio interface
55
+ iface = gr.Interface(
56
+ fn=forecast_stock,
57
+ inputs=[
58
+ gr.Dropdown(choices=['AAPL', 'GOOGL', 'MSFT', 'AMZN'], label="Stock Ticker"),
59
+ gr.Dropdown(choices=['1y', '2y', '5y', '10y', 'max'], label="Historical Data Period"),
60
+ gr.Slider(minimum=30, maximum=365, step=30, label="Days to Forecast"),
61
+ gr.Checkbox(label="Include ARIMA Forecast")
62
+ ],
63
+ outputs=[
64
+ gr.Plot(label="Forecast"),
65
+ gr.Plot(label="Prophet Forecast Components")
66
+ ],
67
+ title="Stock Price Forecasting with Prophet and ARIMA",
68
+ description="Select a stock, historical data period, forecast horizon, and whether to include ARIMA forecast."
69
+ )
70
+
71
+ # Launch the interface
72
+ iface.launch()