# Spaces: Running  (hosting-page status header captured with the source; not part of the program)
from datetime import date, datetime

import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import streamlit as st
import yfinance as yf
from prophet import Prophet
from prophet.plot import plot_components_plotly, plot_plotly
from sklearn.metrics import mean_absolute_error, mean_squared_error
# Configure Streamlit page settings
st.set_page_config(
    page_title="Stock & Crypto Forecast",
    page_icon="📈",
    layout="wide",
)

# Constants and configurations
# START and TODAY bound the date range passed to yf.download in
# AssetPredictor.load_data.
START = "2015-01-01"
TODAY = date.today().strftime("%Y-%m-%d")

# Custom CSS for better styling: stretch buttons to the full width of their
# container and set a light page background.
# NOTE(review): `.reportview-container` may not match any element in current
# Streamlit releases -- confirm the selector against the deployed version.
st.markdown("""
<style>
.stButton>button {
    width: 100%;
}
.reportview-container {
    background: #f0f2f6
}
</style>
""", unsafe_allow_html=True)
class AssetPredictor:
    """Download price history and build Prophet forecasts for known assets."""

    def __init__(self):
        # Ticker symbols offered in the UI, grouped by asset category.
        self.assets = {
            'Stocks': ['GOOG', 'AAPL', 'MSFT', 'GME'],
            'Cryptocurrencies': ['BTC-USD', 'ETH-USD', 'DOGE-USD', 'ADA-USD']
        }

    # NOTE(review): an earlier comment here said "Cache data for 1 hour", but
    # no caching is applied -- every call downloads the history again.
    def load_data(self, ticker):
        """Load and validate financial data.

        Args:
            ticker: Symbol to download between START and TODAY.

        Returns:
            A DataFrame with at least Date, Open, High, Low, Close and Volume
            columns, numeric prices, no missing values and a fresh 0..n-1
            index; or ``None`` when loading fails (the error is shown with
            ``st.error``).
        """
        try:
            data = yf.download(ticker, START, TODAY)
            if data is None or data.empty:
                raise ValueError(f"No data found for {ticker}")

            # Some yfinance versions return two-level (field, ticker) columns
            # even for one ticker; keep only the field names so that
            # data['Close'] is a Series rather than a one-column DataFrame.
            if isinstance(data.columns, pd.MultiIndex):
                data.columns = data.columns.get_level_values(0)
            data = data.reset_index()

            # Ensure all required columns exist and are numeric
            required_columns = ['Date', 'Open', 'High', 'Low', 'Close', 'Volume']
            missing = [col for col in required_columns if col not in data.columns]
            if missing:
                raise ValueError(f"Missing required column: {missing[0]}")

            numeric_columns = [col for col in required_columns if col != 'Date']
            # Unparseable cells become NaN and are removed by dropna below.
            data[numeric_columns] = data[numeric_columns].apply(
                pd.to_numeric, errors='coerce'
            )

            # Re-number the rows so later positional comparisons are not
            # thrown off by gaps left where rows were dropped.
            return data.dropna().reset_index(drop=True)
        except Exception as e:
            # UI boundary: report the problem and let the caller skip rendering.
            st.error(f"Error loading data: {str(e)}")
            return None

    def prepare_prophet_data(self, data):
        """Prepare data for Prophet model.

        Returns a new two-column frame (``ds`` = Date, ``y`` = Close); the
        input frame is not modified.
        """
        df_prophet = data[['Date', 'Close']].copy()
        df_prophet.columns = ['ds', 'y']
        # Prophet raises on timezone-aware timestamps, so drop any tz info
        # (a no-op for naive timestamps).
        df_prophet['ds'] = pd.to_datetime(df_prophet['ds']).dt.tz_localize(None)
        return df_prophet

    def train_prophet_model(self, data, period):
        """Train and return Prophet model with customized parameters.

        Args:
            data: Frame with ``ds`` and ``y`` columns.
            period: Number of future rows to append to the history.

        Returns:
            ``(model, future)``: the fitted model and the frame produced by
            ``make_future_dataframe`` for it.
        """
        # NOTE(review): load_data requests no explicit interval, so the input
        # is presumably one row per day; daily_seasonality may add nothing
        # useful there -- confirm before relying on that component.
        model = Prophet(
            yearly_seasonality=True,
            weekly_seasonality=True,
            daily_seasonality=True,
            changepoint_prior_scale=0.05,
            seasonality_prior_scale=10.0,
            changepoint_range=0.9
        )
        # Add custom seasonalities
        model.add_seasonality(
            name='monthly',
            period=30.5,
            fourier_order=5
        )
        model.fit(data)
        future = model.make_future_dataframe(periods=period)
        return model, future
def _build_technical_figure(data, asset_name):
    """Return a candlestick chart of *data* overlaid with its SMA 20/50 lines."""
    fig = go.Figure()
    fig.add_trace(go.Candlestick(
        x=data['Date'],
        open=data['Open'],
        high=data['High'],
        low=data['Low'],
        close=data['Close'],
        name='Price'
    ))
    overlays = (('SMA_20', 'SMA 20', 'orange'), ('SMA_50', 'SMA 50', 'blue'))
    for column, label, color in overlays:
        fig.add_trace(go.Scatter(
            x=data['Date'],
            y=data[column],
            name=label,
            line=dict(color=color)
        ))
    fig.update_layout(
        title=f'{asset_name} Technical Analysis',
        yaxis_title='Price',
        template='plotly_dark'
    )
    return fig


def _compute_fit_metrics(df_prophet, forecast):
    """Return ``(mae, rmse, mape)`` of the forecast over the observed dates."""
    # Pair each observation with its prediction by timestamp; matching on row
    # labels or position breaks as soon as the two frames are indexed
    # differently.
    paired = df_prophet.merge(forecast[['ds', 'yhat']], on='ds', how='inner')
    paired = paired.dropna(subset=['y', 'yhat'])
    if paired.empty:
        raise ValueError("No overlapping dates between history and forecast")

    actual = paired['y']
    predicted = paired['yhat']
    mae = mean_absolute_error(actual, predicted)
    rmse = np.sqrt(mean_squared_error(actual, predicted))
    # Percentage error is undefined where the actual price is zero.
    nonzero = actual != 0
    mape = np.mean(
        np.abs((actual[nonzero] - predicted[nonzero]) / actual[nonzero])
    ) * 100
    return mae, rmse, mape


def main():
    """Render the app: input widgets, technical chart, forecast and download.

    Reads the asset, forecast horizon and confidence level from Streamlit
    widgets, loads the price history, draws a candlestick/SMA chart, fits a
    Prophet model and shows its fit metrics, forecast, components and a CSV
    download button. Failures in the forecasting stage are shown with
    ``st.error`` instead of propagating.
    """
    predictor = AssetPredictor()

    # Sidebar for user inputs
    st.sidebar.title("⚙️ Configuration")
    asset_type = st.sidebar.radio("Select Asset Type", list(predictor.assets.keys()))
    selected_asset = st.sidebar.selectbox(
        'Select Asset',
        predictor.assets[asset_type]
    )

    # Main content
    st.title('📈 Advanced Stock & Cryptocurrency Forecast')

    # Time range selection
    col1, col2 = st.columns(2)
    with col1:
        n_years = st.slider('Forecast Period (Years):', 1, 4)
    with col2:
        confidence_level = st.slider('Confidence Level:', 0.8, 0.99, 0.95)
    period = n_years * 365

    # Load and process data
    with st.spinner('Loading data...'):
        data = predictor.load_data(selected_asset)
    if data is None:
        # load_data has already reported the failure to the user.
        return

    # Display technical indicators
    st.subheader('📊 Technical Analysis')
    data['SMA_20'] = data['Close'].rolling(window=20).mean()
    data['SMA_50'] = data['Close'].rolling(window=50).mean()
    # RSI is stored on the frame; it is not drawn on the chart below.
    data['RSI'] = calculate_rsi(data['Close'])
    st.plotly_chart(
        _build_technical_figure(data, selected_asset),
        use_container_width=True
    )

    # Prepare and train Prophet model
    df_prophet = predictor.prepare_prophet_data(data)
    try:
        model, future = predictor.train_prophet_model(df_prophet, period)
        # Apply the user's confidence level: Prophet reads interval_width
        # when predict() builds the yhat_lower / yhat_upper band.
        model.interval_width = confidence_level
        forecast = model.predict(future)

        # Calculate performance metrics
        mae, rmse, mape = _compute_fit_metrics(df_prophet, forecast)

        # Display metrics in columns
        st.subheader('📉 Model Performance Metrics')
        col1, col2, col3 = st.columns(3)
        col1.metric("MAE", f"${mae:.2f}")
        col2.metric("RMSE", f"${rmse:.2f}")
        col3.metric("MAPE", f"{mape:.2f}%")

        # Forecast visualization
        st.subheader('🔮 Price Forecast')
        fig_forecast = plot_plotly(model, forecast)
        fig_forecast.update_layout(template='plotly_dark')
        st.plotly_chart(fig_forecast, use_container_width=True)

        # Show forecast components. Model.plot_components returns a
        # Matplotlib figure, so use the Plotly variant for st.plotly_chart.
        st.subheader("📊 Forecast Components")
        fig_components = plot_components_plotly(model, forecast)
        fig_components.update_layout(template='plotly_dark')
        st.plotly_chart(fig_components, use_container_width=True)

        # Download forecast data
        csv = convert_df_to_csv(forecast)
        st.download_button(
            label="Download Forecast Data",
            data=csv,
            file_name=f'{selected_asset}_forecast.csv',
            mime='text/csv'
        )
    except Exception as e:
        # UI boundary: surface any modelling/plotting failure on the page.
        st.error(f"Error in prediction: {str(e)}")
def calculate_rsi(prices, period=14):
    """Calculate Relative Strength Index.

    Uses simple rolling means of gains and losses over ``period`` price
    changes.

    Args:
        prices: pandas Series of prices in chronological order.
        period: Number of consecutive price changes per window.

    Returns:
        A Series aligned with ``prices``. The first ``period`` entries are
        NaN (a window needs ``period`` real changes, i.e. ``period + 1``
        prices). A window with gains but no losses yields 100; a window with
        no price movement at all yields NaN.
    """
    delta = prices.diff()
    # clip keeps the leading NaN of diff(), so the warm-up window is not
    # padded with a fictitious zero change.
    gain = delta.clip(lower=0).rolling(window=period).mean()
    loss = (-delta).clip(lower=0).rolling(window=period).mean()
    rs = gain / loss
    return 100 - (100 / (1 + rs))
def convert_df_to_csv(df):
    """Serialize *df* to UTF-8 encoded CSV bytes, leaving out the index."""
    csv_text = df.to_csv(index=False)
    return csv_text.encode('utf-8')
# Entry point: build the page only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()