Kr08 commited on
Commit
6a50720
·
verified ·
1 Parent(s): 7b389b1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -56
app.py CHANGED
@@ -2,8 +2,7 @@ import datetime
2
  import gradio as gr
3
  import pandas as pd
4
  import yfinance as yf
5
- import seaborn as sns;
6
-
7
  sns.set()
8
  import matplotlib.pyplot as plt
9
  import plotly.graph_objects as go
@@ -18,6 +17,7 @@ from dateutil.relativedelta import relativedelta
18
 
19
  index_options = ['FTSE 100(UK)', 'NASDAQ(USA)', 'CAC 40(FRANCE)']
20
  ticker_dict = {'FTSE 100(UK)': 'FTSE 100', 'NASDAQ(USA)': 'NASDAQ 100', 'CAC 40(FRANCE)': 'CAC 40'}
 
21
 
22
  global START_DATE, END_DATE
23
 
@@ -28,23 +28,18 @@ demo = gr.Blocks()
28
  stock_names = []
29
 
30
  with demo:
31
- d1 = gr.Dropdown(index_options, label='Please select Index...',
32
- info='Will be adding more indices later on',
33
- interactive=True)
34
-
35
- d2 = gr.Dropdown([]) # for specific stocks
36
-
37
-
38
- # d3 = gr.Dropdown(['General News'])
39
 
40
  def forecast_series(series, model="ARIMA", forecast_period=7):
41
-
42
  predictions = list()
43
  if series.shape[1] > 1:
44
  series = series['Close'].values.tolist()
45
- plt.show()
46
  if model == "ARIMA":
47
- ## Do grid search here --> Custom for all stocks
48
  for i in range(forecast_period):
49
  model = ARIMA(series, order=(5, 1, 0))
50
  model_fit = model.fit()
@@ -52,74 +47,69 @@ with demo:
52
  yhat = output[0]
53
  predictions.append(yhat)
54
  series.append(yhat)
 
 
 
 
 
 
55
 
56
  return predictions
57
 
58
-
59
  def is_business_day(a_date):
60
  return a_date.weekday() < 5
61
 
62
-
63
  def get_stocks_from_index(idx):
64
  stock_data = PyTickerSymbols()
65
- # indices = stock_data.get_all_indices()
66
  index = ticker_dict[idx]
67
- stock_data = PyTickerSymbols()
68
-
69
- # returns 2d list with the following information
70
- # 'name', 'symbol', 'country', 'indices', 'industries', 'symbols', 'metadata', 'isins', 'akas'
71
- stocks = list(stock_data.get_stocks_by_index(index)) ##converting filter object to list
72
- stock_names = []
73
- for stock in stocks:
74
- stock_names.append(stock['name'] + ':' + stock['symbol'])
75
- d2 = gr.Dropdown(choices=stock_names, label='Please Select Stock from your selected index', interactive=True)
76
- return d2
77
-
78
 
79
  d1.input(get_stocks_from_index, d1, d2)
80
- out = gr.Plot(every=10)
81
-
82
-
83
- def get_stock_graph(idx, stock):
84
-
85
- stock_name = stock.split(":")[0]
86
- ticker_name = stock.split(":")[1]
87
 
 
 
 
88
  if ticker_dict[idx] == 'FTSE 100':
89
- if ticker_name[-1] == '.':
90
- ticker_name += 'L'
91
- else:
92
- ticker_name += '.L'
93
  elif ticker_dict[idx] == 'CAC 40':
94
  ticker_name += '.PA'
95
 
96
- ## Can also download lower interval data apparently using line below
97
- # data = yf.download(tickers="MSFT", period="5d", interval="1m")
98
- series = yf.download(tickers=ticker_name, start=START_DATE, end=END_DATE) # stock.split(":")[1]
99
  series = series.reset_index()
100
 
101
- predictions = forecast_series(series)
102
 
103
  last_date = pd.to_datetime(series['Date'].values[-1])
104
- forecast_week = []
105
-
106
- while len(forecast_week) != FORECAST_PERIOD:
107
- if is_business_day(last_date):
108
- forecast_week.append(last_date)
109
- last_date += timedelta(days=1)
110
 
111
  forecast = pd.DataFrame({"Date": forecast_week, "Forecast": predictions})
112
 
113
- fig = plt.figure(figsize=(14, 5))
114
- sns.set_style("ticks")
115
- sns.lineplot(data=series, x="Date", y="Close", color="firebrick")
116
- sns.lineplot(data=forecast, x="Date", y="Forecast", color="blue")
117
- sns.despine()
 
 
 
 
 
 
 
 
 
 
 
118
 
119
- plt.title("Stock Price of {}".format(stock_name), size='x-large', color='blue') # stock.split(":")[0]
120
- text = "Your stock is:" + str(stock)
121
  return fig
122
 
 
 
 
 
 
 
123
 
124
- d2.input(get_stock_graph, [d1, d2], out)
125
  demo.launch()
 
2
  import gradio as gr
3
  import pandas as pd
4
  import yfinance as yf
5
+ import seaborn as sns
 
6
  sns.set()
7
  import matplotlib.pyplot as plt
8
  import plotly.graph_objects as go
 
17
 
18
  index_options = ['FTSE 100(UK)', 'NASDAQ(USA)', 'CAC 40(FRANCE)']
19
  ticker_dict = {'FTSE 100(UK)': 'FTSE 100', 'NASDAQ(USA)': 'NASDAQ 100', 'CAC 40(FRANCE)': 'CAC 40'}
20
+ time_intervals = ['1d', '1m', '5m', '15m', '60m']
21
 
22
  global START_DATE, END_DATE
23
 
 
28
  stock_names = []
29
 
30
  with demo:
31
+ d1 = gr.Dropdown(index_options, label='Please select Index...', info='Will be adding more indices later on', interactive=True)
32
+ d2 = gr.Dropdown([], label='Please Select Stock from your selected index', interactive=True)
33
+ d3 = gr.Dropdown(time_intervals, label='Select Time Interval', value='1d', interactive=True)
34
+ d4 = gr.Radio(['Line Graph', 'Candlestick Graph'], label='Select Graph Type', value='Line Graph', interactive=True)
35
+ d5 = gr.Dropdown(['ARIMA', 'Prophet', 'LSTM'], label='Select Forecasting Method', value='ARIMA', interactive=True)
 
 
 
36
 
37
  def forecast_series(series, model="ARIMA", forecast_period=7):
 
38
  predictions = list()
39
  if series.shape[1] > 1:
40
  series = series['Close'].values.tolist()
41
+
42
  if model == "ARIMA":
 
43
  for i in range(forecast_period):
44
  model = ARIMA(series, order=(5, 1, 0))
45
  model_fit = model.fit()
 
47
  yhat = output[0]
48
  predictions.append(yhat)
49
  series.append(yhat)
50
+ elif model == "Prophet":
51
+ # Implement Prophet forecasting method
52
+ pass
53
+ elif model == "LSTM":
54
+ # Implement LSTM forecasting method
55
+ pass
56
 
57
  return predictions
58
 
 
59
  def is_business_day(a_date):
60
  return a_date.weekday() < 5
61
 
 
62
  def get_stocks_from_index(idx):
63
  stock_data = PyTickerSymbols()
 
64
  index = ticker_dict[idx]
65
+ stocks = list(stock_data.get_stocks_by_index(index))
66
+ stock_names = [f"{stock['name']}:{stock['symbol']}" for stock in stocks]
67
+ return gr.Dropdown(choices=stock_names, label='Please Select Stock from your selected index', interactive=True)
 
 
 
 
 
 
 
 
68
 
69
  d1.input(get_stocks_from_index, d1, d2)
 
 
 
 
 
 
 
70
 
71
+ def get_stock_graph(idx, stock, interval, graph_type, forecast_method):
72
+ stock_name, ticker_name = stock.split(":")
73
+
74
  if ticker_dict[idx] == 'FTSE 100':
75
+ ticker_name += '.L' if ticker_name[-1] != '.' else 'L'
 
 
 
76
  elif ticker_dict[idx] == 'CAC 40':
77
  ticker_name += '.PA'
78
 
79
+ series = yf.download(tickers=ticker_name, start=START_DATE, end=END_DATE, interval=interval)
 
 
80
  series = series.reset_index()
81
 
82
+ predictions = forecast_series(series, model=forecast_method)
83
 
84
  last_date = pd.to_datetime(series['Date'].values[-1])
85
+ forecast_week = [last_date + timedelta(days=i) for i in range(1, FORECAST_PERIOD + 1) if is_business_day(last_date + timedelta(days=i))]
 
 
 
 
 
86
 
87
  forecast = pd.DataFrame({"Date": forecast_week, "Forecast": predictions})
88
 
89
+ if graph_type == 'Line Graph':
90
+ fig = go.Figure()
91
+ fig.add_trace(go.Scatter(x=series['Date'], y=series['Close'], mode='lines', name='Historical'))
92
+ fig.add_trace(go.Scatter(x=forecast['Date'], y=forecast['Forecast'], mode='lines', name='Forecast'))
93
+ else: # Candlestick Graph
94
+ fig = go.Figure(data=[go.Candlestick(x=series['Date'],
95
+ open=series['Open'],
96
+ high=series['High'],
97
+ low=series['Low'],
98
+ close=series['Close'],
99
+ name='Historical')])
100
+ fig.add_trace(go.Scatter(x=forecast['Date'], y=forecast['Forecast'], mode='lines', name='Forecast'))
101
+
102
+ fig.update_layout(title=f"Stock Price of {stock_name}",
103
+ xaxis_title="Date",
104
+ yaxis_title="Price")
105
 
 
 
106
  return fig
107
 
108
+ out = gr.Plot()
109
+ inputs = [d1, d2, d3, d4, d5]
110
+ d2.input(get_stock_graph, inputs, out)
111
+ d3.input(get_stock_graph, inputs, out)
112
+ d4.input(get_stock_graph, inputs, out)
113
+ d5.input(get_stock_graph, inputs, out)
114
 
 
115
  demo.launch()