Kr08's picture
Update app.py
6a50720 verified
raw
history blame
4.76 kB
import datetime
import gradio as gr
import pandas as pd
import yfinance as yf
import seaborn as sns
sns.set()
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from datetime import date, timedelta
from matplotlib import pyplot as plt
from plotly.subplots import make_subplots
from pytickersymbols import PyTickerSymbols
from statsmodels.tsa.arima.model import ARIMA
from pandas.plotting import autocorrelation_plot
from dateutil.relativedelta import relativedelta
# UI display name of each supported index -> index name used by PyTickerSymbols.
ticker_dict = {'FTSE 100(UK)': 'FTSE 100', 'NASDAQ(USA)': 'NASDAQ 100', 'CAC 40(FRANCE)': 'CAC 40'}
# Choices for the index dropdown, derived from ticker_dict so the two lists
# can never drift apart (dict order is insertion order).
index_options = list(ticker_dict)
# Bar sizes offered in the UI; passed to yf.download(interval=...).
# NOTE(review): Yahoo only serves intraday bars (1m/5m/15m/60m) for a limited
# recent window, so a one-year range only fully applies to '1d'.
time_intervals = ['1d', '1m', '5m', '15m', '60m']
# Default one-year history window. Evaluated once at import time, so in a
# long-running process these keep the start-up date.
END_DATE = date.today()
START_DATE = END_DATE - relativedelta(years=1)
# Number of future steps covered by the forecast.
FORECAST_PERIOD = 7
demo = gr.Blocks()
# NOTE(review): this module-level list is not read anywhere in this file.
stock_names = []
with demo:
    # Index picker; selecting an index fills the stock dropdown below.
    d1 = gr.Dropdown(
        index_options,
        info='Will be adding more indices later on',
        label='Please select Index...',
        interactive=True,
    )
    # Stock picker; it has no choices until an index is selected.
    d2 = gr.Dropdown(
        [],
        label='Please Select Stock from your selected index',
        interactive=True,
    )
    # yfinance bar size, defaulting to daily bars.
    d3 = gr.Dropdown(
        time_intervals,
        value='1d',
        label='Select Time Interval',
        interactive=True,
    )
    # Chart style for the historical prices.
    d4 = gr.Radio(
        ['Line Graph', 'Candlestick Graph'],
        value='Line Graph',
        label='Select Graph Type',
        interactive=True,
    )
    # NOTE(review): only 'ARIMA' has an implementation in forecast_series.
    d5 = gr.Dropdown(
        ['ARIMA', 'Prophet', 'LSTM'],
        value='ARIMA',
        label='Select Forecasting Method',
        interactive=True,
    )
def forecast_series(series, model="ARIMA", forecast_period=7, order=(5, 1, 0)):
    """Forecast the next ``forecast_period`` values of a price series.

    Args:
        series: Price history. A DataFrame with several columns uses its
            'Close' column; a single-column DataFrame uses that column; a
            Series or plain sequence is used as-is. Missing values are dropped.
        model: Forecasting method name. Only "ARIMA" is implemented.
        forecast_period: Number of future steps to predict.
        order: ARIMA (p, d, q) order used when ``model == "ARIMA"``.

    Returns:
        A list of ``forecast_period`` floats, or ``[]`` if forecast_period <= 0.

    Raises:
        NotImplementedError: for the planned "Prophet" and "LSTM" methods.
        ValueError: for an unknown method name or a series with no values.
    """
    if model in ("Prophet", "LSTM"):
        # Offered in the UI but not built yet. Fail loudly instead of returning
        # an empty list that breaks callers expecting forecast_period values.
        raise NotImplementedError(f"{model} forecasting is not implemented yet")
    if model != "ARIMA":
        raise ValueError(f"Unknown forecasting model: {model!r}")
    if forecast_period <= 0:
        return []

    if isinstance(series, pd.DataFrame):
        series = series.iloc[:, 0] if series.shape[1] == 1 else series["Close"]
        if isinstance(series, pd.DataFrame):
            # yfinance may nest 'Close' under a ticker level (MultiIndex columns).
            series = series.iloc[:, 0]
    # Plain floats: avoids index-based lookups on the forecast result.
    values = pd.Series(series, dtype=float).dropna().tolist()
    if not values:
        raise ValueError("Cannot forecast an empty series")

    # Fit once and forecast all steps together. Refitting on the model's own
    # one-step predictions adds no new information and costs one fit per step.
    fitted = ARIMA(values, order=order).fit()
    return [float(v) for v in fitted.forecast(steps=forecast_period)]
def is_business_day(a_date):
    """Return True if *a_date* falls on a weekday (Monday to Friday).

    Only the day of the week is considered; holidays are not excluded.
    """
    # isoweekday() numbers Monday as 1 through Sunday as 7.
    return a_date.isoweekday() <= 5
def get_stocks_from_index(idx):
    """Return a stock Dropdown listing the constituents of the chosen index.

    Args:
        idx: Display name of the selected index (a key of ``ticker_dict``).

    Returns:
        A ``gr.Dropdown`` whose choices are formatted "<company name>:<symbol>".
        Its value is reset to None, because a stock picked from the previously
        selected index is not a valid choice for the new one.
    """
    index = ticker_dict[idx]
    stocks = PyTickerSymbols().get_stocks_by_index(index)
    choices = [f"{stock['name']}:{stock['symbol']}" for stock in stocks]
    return gr.Dropdown(
        choices=choices,
        value=None,
        label='Please Select Stock from your selected index',
        interactive=True,
    )


with demo:
    # Repopulate the stock dropdown whenever the user picks an index.
    d1.input(get_stocks_from_index, d1, d2)
# Yahoo Finance only serves intraday bars for a limited recent window;
# a one-year request at these intervals comes back empty.
_MAX_LOOKBACK = {
    '1m': timedelta(days=7),
    '5m': timedelta(days=59),
    '15m': timedelta(days=59),
    '60m': timedelta(days=729),
}
# Spacing between consecutive bars, used to place forecast points.
_BAR_STEP = {
    '1d': timedelta(days=1),
    '1m': timedelta(minutes=1),
    '5m': timedelta(minutes=5),
    '15m': timedelta(minutes=15),
    '60m': timedelta(minutes=60),
}


def _yahoo_ticker(idx, symbol):
    """Return the Yahoo Finance ticker for *symbol* listed on index *idx*."""
    if ticker_dict[idx] == 'FTSE 100':
        # London listing; symbols already ending in '.' only need the 'L'.
        return symbol + ('L' if symbol.endswith('.') else '.L')
    if ticker_dict[idx] == 'CAC 40':
        return symbol + '.PA'  # Paris listing
    return symbol


def _download_history(ticker, interval):
    """Download OHLC bars for *ticker* as a frame with a 'Date' column."""
    end = date.today() + timedelta(days=1)  # yfinance treats `end` as exclusive
    start = end - relativedelta(years=1)
    if interval in _MAX_LOOKBACK:
        start = max(start, end - _MAX_LOOKBACK[interval])
    frame = yf.download(tickers=ticker, start=start, end=end, interval=interval)
    if frame is None or frame.empty:
        raise gr.Error(f"No price data returned for {ticker} at interval {interval}.")
    # Newer yfinance versions return (field, ticker) MultiIndex columns even
    # for a single ticker; keep only the field level ('Open', 'Close', ...).
    if isinstance(frame.columns, pd.MultiIndex):
        frame.columns = frame.columns.get_level_values(0)
    # Daily bars are indexed by 'Date', intraday bars by 'Datetime'; unify.
    return frame.rename_axis('Date').reset_index()


def _forecast_frame(last_date, predictions, interval):
    """Pair each prediction with the next weekday bar time after *last_date*."""
    step = _BAR_STEP.get(interval, timedelta(days=1))
    dates = []
    candidate = last_date
    # Produce exactly one timestamp per prediction so the columns line up.
    while len(dates) < len(predictions):
        candidate = candidate + step
        if is_business_day(candidate):
            dates.append(candidate)
    return pd.DataFrame({"Date": dates, "Forecast": predictions})


def get_stock_graph(idx, stock, interval, graph_type, forecast_method):
    """Plot the selected stock's price history with a forecast overlay.

    Args:
        idx: Display name of the selected index (a key of ``ticker_dict``).
        stock: Selected choice formatted as "<company name>:<symbol>".
        interval: yfinance bar size, e.g. '1d' or '5m'.
        graph_type: 'Line Graph'; any other value draws candlesticks.
        forecast_method: Method name passed to ``forecast_series``.

    Returns:
        A plotly Figure, or None while no index/stock has been selected.

    Raises:
        gr.Error: if Yahoo Finance returns no data for the request.
    """
    if not idx or not stock:
        # Interval/graph/method controls can fire before a stock is chosen.
        return None
    # Company names may themselves contain ':'; the symbol follows the last one.
    stock_name, symbol = stock.rsplit(":", 1)
    series = _download_history(_yahoo_ticker(idx, symbol), interval)

    title = f"Stock Price of {stock_name}"
    try:
        predictions = list(forecast_series(series, model=forecast_method,
                                           forecast_period=FORECAST_PERIOD))
    except NotImplementedError:
        # Method listed in the UI but not built yet: show history only.
        predictions = []
        title += f" ({forecast_method} forecast not available yet)"
    last_date = pd.to_datetime(series['Date'].iloc[-1])
    forecast = _forecast_frame(last_date, predictions, interval)

    if graph_type == 'Line Graph':
        fig = go.Figure()
        fig.add_trace(go.Scatter(x=series['Date'], y=series['Close'], mode='lines', name='Historical'))
    else:  # Candlestick Graph
        fig = go.Figure(data=[go.Candlestick(x=series['Date'],
                                             open=series['Open'],
                                             high=series['High'],
                                             low=series['Low'],
                                             close=series['Close'],
                                             name='Historical')])
    if not forecast.empty:
        fig.add_trace(go.Scatter(x=forecast['Date'], y=forecast['Forecast'], mode='lines', name='Forecast'))
    fig.update_layout(title=title,
                      xaxis_title="Date",
                      yaxis_title="Price")
    return fig
with demo:
    out = gr.Plot()
    # All five selector values are passed, in the parameter order of
    # get_stock_graph(idx, stock, interval, graph_type, forecast_method).
    inputs = [d1, d2, d3, d4, d5]
    # Changing the stock, interval, graph type or forecast method redraws the chart.
    for trigger in (d2, d3, d4, d5):
        trigger.input(get_stock_graph, inputs, out)
demo.launch()