# trading-master / app.py
# Hugging Face file page metadata: author shaheerawan3, commit 673e0a8 ("Update app.py", verified), 7.3 kB.
import streamlit as st
from datetime import date, datetime
import yfinance as yf
import pandas as pd
import numpy as np
from prophet import Prophet
from prophet.plot import plot_plotly
import plotly.graph_objects as go
from sklearn.metrics import mean_absolute_error, mean_squared_error
import plotly.express as px
# Configure Streamlit page settings: browser-tab title, favicon and wide layout.
# NOTE: Streamlit documents set_page_config as needing to run before other
# st.* calls, so keep this at the top of the script.
st.set_page_config(
    page_title="Stock & Crypto Forecast",
    page_icon="📈",  # chart-increasing emoji (U+1F4C8); file must stay UTF-8
    layout="wide"
)
# Date range used for every price download, as "YYYY-MM-DD" strings.
# First day of the requested history.
START = "2015-01-01"
# Today's local date, used as the download end date.
# NOTE(review): yfinance appears to treat `end` as exclusive, so today's bar is
# likely not included — confirm if same-day data matters.
TODAY = date.today().isoformat()
# Page-wide style overrides injected as raw HTML: buttons stretch to the full
# width of their container, and the `.reportview-container` element gets a
# light-grey background.
# NOTE(review): `.reportview-container` looks like an older Streamlit DOM class;
# confirm it still matches in the installed Streamlit version.
_CUSTOM_CSS = """
<style>
.stButton>button {
width: 100%;
}
.reportview-container {
background: #f0f2f6
}
</style>
"""
st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)
@st.cache_data(ttl=3600)  # Cache successful downloads for 1 hour
def _fetch_price_history(ticker, start, end):
    """Download OHLCV history for ``ticker`` between ``start`` and ``end``.

    This is a module-level function rather than a method, so Streamlit builds
    the cache key from plain strings only; it never has to hash a predictor
    instance. ``start``/``end`` are parameters so the key changes when the
    date range does.

    Returns:
        DataFrame with a ``Date`` column plus numeric Open/High/Low/Close/Volume
        columns; rows containing any missing value are dropped.

    Raises:
        ValueError: if nothing was downloaded or a required column is missing.
            Raising (instead of returning ``None``) keeps the failure out of
            the cache, so a transient error is retried on the next rerun
            rather than being pinned for the whole TTL.
    """
    data = yf.download(ticker, start, end)
    if data.empty:
        raise ValueError(f"No data found for {ticker}")
    # Newer yfinance releases return two-level columns even for one ticker
    # (assumed (field, ticker) order, yfinance's default grouping — confirm if
    # group_by is ever changed). Keep only the field level so data['Close'] is
    # a Series and pd.to_numeric below gets 1-D input.
    if isinstance(data.columns, pd.MultiIndex):
        data.columns = data.columns.get_level_values(0)
    data = data.reset_index()
    # Ensure all required columns exist and are numeric
    required_columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
    for col in required_columns:
        if col not in data.columns:
            raise ValueError(f"Missing required column: {col}")
        if col != 'Date':
            data[col] = pd.to_numeric(data[col], errors='coerce')
    # Drop rows with any missing value, including values coerced to NaN above.
    return data.dropna()


class AssetPredictor:
    """Load price history for a fixed set of tickers and fit Prophet forecasts."""

    def __init__(self):
        # Tickers offered in the sidebar, keyed by asset type.
        self.assets = {
            'Stocks': ['GOOG', 'AAPL', 'MSFT', 'GME'],
            'Cryptocurrencies': ['BTC-USD', 'ETH-USD', 'DOGE-USD', 'ADA-USD']
        }

    def load_data(self, ticker):
        """Load and validate financial data for ``ticker`` from START to TODAY.

        Returns:
            The cleaned DataFrame (see ``_fetch_price_history``), or ``None``
            after showing an ``st.error`` message if loading failed.
        """
        try:
            return _fetch_price_history(ticker, START, TODAY)
        except Exception as e:  # UI boundary: report any failure to the user
            st.error(f"Error loading data: {str(e)}")
            return None

    def prepare_prophet_data(self, data):
        """Return a copy of ``data``'s Date/Close columns renamed to Prophet's ``ds``/``y``."""
        df_prophet = data[['Date', 'Close']].copy()
        df_prophet.columns = ['ds', 'y']
        # Prophet rejects timezone-aware timestamps in ``ds``. If the source
        # supplied a zone, drop it and keep the local wall-clock time.
        if isinstance(df_prophet['ds'].dtype, pd.DatetimeTZDtype):
            df_prophet['ds'] = df_prophet['ds'].dt.tz_localize(None)
        return df_prophet

    def train_prophet_model(self, data, period, interval_width=0.80):
        """Fit a Prophet model on ``data`` and build the frame to forecast over.

        Args:
            data: DataFrame with ``ds``/``y`` columns (see ``prepare_prophet_data``).
            period: number of future periods appended after the history.
            interval_width: width of the uncertainty interval. The default 0.80
                is Prophet's own default, so existing callers are unaffected.

        Returns:
            ``(model, future)``: the fitted model and the history-plus-future
            frame to pass to ``model.predict``.
        """
        model = Prophet(
            yearly_seasonality=True,
            weekly_seasonality=True,
            # NOTE(review): assuming daily bars (yf.download's default interval),
            # daily seasonality cannot be identified from one point per day;
            # consider disabling it.
            daily_seasonality=True,
            changepoint_prior_scale=0.05,
            seasonality_prior_scale=10.0,
            changepoint_range=0.9,
            interval_width=interval_width,
        )
        # Add a custom ~monthly seasonality on top of the built-in ones.
        model.add_seasonality(
            name='monthly',
            period=30.5,
            fourier_order=5
        )
        model.fit(data)
        future = model.make_future_dataframe(periods=period)
        return model, future
def _render_technical_analysis(data, selected_asset):
    """Plot price candles with 20/50-row SMAs; adds SMA_20, SMA_50 and RSI columns to ``data``."""
    st.subheader('📊 Technical Analysis')
    # Calculate technical indicators
    data['SMA_20'] = data['Close'].rolling(window=20).mean()
    data['SMA_50'] = data['Close'].rolling(window=50).mean()
    # NOTE(review): RSI is computed but not plotted anywhere.
    data['RSI'] = calculate_rsi(data['Close'])

    fig_technical = go.Figure()
    fig_technical.add_trace(go.Candlestick(
        x=data['Date'],
        open=data['Open'],
        high=data['High'],
        low=data['Low'],
        close=data['Close'],
        name='Price'
    ))
    fig_technical.add_trace(go.Scatter(
        x=data['Date'],
        y=data['SMA_20'],
        name='SMA 20',
        line=dict(color='orange')
    ))
    fig_technical.add_trace(go.Scatter(
        x=data['Date'],
        y=data['SMA_50'],
        name='SMA 50',
        line=dict(color='blue')
    ))
    fig_technical.update_layout(
        title=f'{selected_asset} Technical Analysis',
        yaxis_title='Price',
        template='plotly_dark'
    )
    st.plotly_chart(fig_technical, use_container_width=True)


def _render_forecast(predictor, data, selected_asset, period, confidence_level):
    """Fit Prophet on closing prices; show in-sample metrics, forecast, components and a CSV download."""
    from prophet.plot import plot_components_plotly

    df_prophet = predictor.prepare_prophet_data(data)
    try:
        model, future = predictor.train_prophet_model(df_prophet, period)
        # Prophet computes yhat_lower/yhat_upper in predict() from
        # model.interval_width. Setting it here makes the confidence slider
        # actually control the bands.
        model.interval_width = confidence_level
        forecast = model.predict(future)

        # In-sample accuracy: pair each observed close with the fitted value for
        # the same date. Joining on ``ds`` avoids pandas index alignment, which
        # mis-pairs rows when df_prophet's index has gaps.
        fitted = df_prophet.merge(forecast[['ds', 'yhat']], on='ds', how='inner')
        actual = fitted['y'].to_numpy()
        predicted = fitted['yhat'].to_numpy()
        mae = mean_absolute_error(actual, predicted)
        rmse = np.sqrt(mean_squared_error(actual, predicted))
        mape = np.mean(np.abs((actual - predicted) / actual)) * 100

        # Display metrics in columns
        st.subheader('📉 Model Performance Metrics')
        col1, col2, col3 = st.columns(3)
        col1.metric("MAE", f"${mae:.2f}")
        col2.metric("RMSE", f"${rmse:.2f}")
        col3.metric("MAPE", f"{mape:.2f}%")

        # Forecast visualization
        st.subheader('🔮 Price Forecast')
        fig_forecast = plot_plotly(model, forecast)
        fig_forecast.update_layout(template='plotly_dark')
        st.plotly_chart(fig_forecast, use_container_width=True)

        # model.plot_components() returns a Matplotlib figure, which is not a
        # Plotly chart. Use Prophet's Plotly variant so st.plotly_chart renders
        # it in the same dark theme.
        st.subheader("📊 Forecast Components")
        fig_components = plot_components_plotly(model, forecast)
        fig_components.update_layout(template='plotly_dark')
        st.plotly_chart(fig_components, use_container_width=True)

        # Download forecast data
        csv = convert_df_to_csv(forecast)
        st.download_button(
            label="Download Forecast Data",
            data=csv,
            file_name=f'{selected_asset}_forecast.csv',
            mime='text/csv'
        )
    except Exception as e:  # UI boundary: show the failure instead of crashing the page
        st.error(f"Error in prediction: {str(e)}")


def main():
    """Render the forecasting page.

    Reads asset type, ticker, forecast horizon and confidence level from the
    UI, then loads the ticker's history. If loading succeeds, it draws the
    technical-analysis chart and the Prophet forecast section. If loading
    fails, ``load_data`` has already shown an error and nothing else is drawn.
    """
    predictor = AssetPredictor()

    # Sidebar for user inputs
    st.sidebar.title("⚙️ Configuration")
    asset_type = st.sidebar.radio("Select Asset Type", list(predictor.assets.keys()))
    selected_asset = st.sidebar.selectbox(
        'Select Asset',
        predictor.assets[asset_type]
    )

    # Main content
    st.title('📈 Advanced Stock & Cryptocurrency Forecast')
    # Time range selection
    col1, col2 = st.columns(2)
    with col1:
        n_years = st.slider('Forecast Period (Years):', 1, 4)
    with col2:
        confidence_level = st.slider('Confidence Level:', 0.8, 0.99, 0.95)
    # Forecast horizon in days (365 per selected year).
    period = n_years * 365

    # Load and process data
    with st.spinner('Loading data...'):
        data = predictor.load_data(selected_asset)
    if data is None:
        return

    _render_technical_analysis(data, selected_asset)
    _render_forecast(predictor, data, selected_asset, period, confidence_level)
def calculate_rsi(prices, period=14):
    """Return the Relative Strength Index of ``prices``.

    The average gain and average loss are simple rolling means over ``period``
    rows. The undefined first difference counts as a zero move.

    Returns:
        A Series aligned with ``prices``. The first ``period - 1`` values are
        NaN; windows with no gains and no losses are also NaN; windows with no
        losses give 100.
    """
    change = prices.diff()
    up_moves = change.clip(lower=0).fillna(0)
    down_moves = (-change).clip(lower=0).fillna(0)
    avg_up = up_moves.rolling(window=period).mean()
    avg_down = down_moves.rolling(window=period).mean()
    strength = avg_up / avg_down
    return 100 - 100 / (1 + strength)
@st.cache_data
def convert_df_to_csv(df):
    """Serialise ``df`` to UTF-8 encoded CSV bytes without the index column.

    Memoised with ``st.cache_data``; the bytes are passed to ``st.download_button``.
    """
    csv_text = df.to_csv(index=False)
    return csv_text.encode('utf-8')
if __name__ == "__main__":
    # Build the page when this file is executed as the main module.
    main()