# Hugging Face Space app (page status when captured: "Sleeping").
import os
import tempfile

import gradio as gr
import joblib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.metrics import mean_absolute_error, mean_squared_error
from sklearn.preprocessing import MinMaxScaler
# Read the web-traffic table; 'Hour Index' is an offset in hours.
webtraffic_data = pd.read_csv("webtraffic.csv")

# Anchor the hour offsets at 2024-01-01 00:00 to obtain real timestamps,
# then discard the offset column.
start_date = pd.Timestamp("2024-01-01 00:00:00")
webtraffic_data['Datetime'] = start_date + pd.to_timedelta(
    webtraffic_data['Hour Index'], unit='h'
)
webtraffic_data = webtraffic_data.drop(columns=['Hour Index'])

# Chronological 80/20 split: first 80% of rows train, the remainder test
# (no shuffling, so the test window strictly follows the training window).
train_size = int(len(webtraffic_data) * 0.8)
train_data, test_data = webtraffic_data.iloc[:train_size], webtraffic_data.iloc[train_size:]
# Forecasters trained offline and shipped next to this script.
sarima_model = joblib.load("sarima_model.pkl")  # SARIMA model
lstm_model = tf.keras.models.load_model("lstm_model.keras")  # LSTM model

# Min-max scale 'Sessions' into [0, 1] for the LSTM. Each scaler is fit on the
# training split only and then applied unchanged to the test split.
# NOTE(review): scaler_X and scaler_y are fit on the same column, so they are
# identical, and the LSTM's input and target are both 'Sessions'. Confirm
# this mirrors how lstm_model was trained.
scaler_X = MinMaxScaler(feature_range=(0, 1))
scaler_y = MinMaxScaler(feature_range=(0, 1))

X_train_scaled = scaler_X.fit_transform(train_data['Sessions'].to_numpy().reshape(-1, 1))
y_train_scaled = scaler_y.fit_transform(train_data['Sessions'].to_numpy().reshape(-1, 1))
X_test_scaled = scaler_X.transform(test_data['Sessions'].to_numpy().reshape(-1, 1))
y_test_scaled = scaler_y.transform(test_data['Sessions'].to_numpy().reshape(-1, 1))

# The LSTM expects (samples, time_steps, features); each sample here is a
# single time step carrying a single feature.
X_test_lstm = X_test_scaled.reshape(-1, 1, 1)
# Forecast the entire test window with both models.
future_periods = len(test_data)

# NOTE(review): `predict(n_periods=...)` matches the pmdarima API; confirm the
# pickled model's type. Such a forecast may be returned as a pandas Series
# carrying its own index rather than a plain array.
sarima_predictions = sarima_model.predict(n_periods=future_periods)

# The LSTM outputs values in scaled space; map them back to session counts.
lstm_predictions_scaled = lstm_model.predict(X_test_lstm[:future_periods])
lstm_predictions = scaler_y.inverse_transform(lstm_predictions_scaled)

# Side-by-side frame for plotting. Both forecasts are flattened to NumPy
# arrays so they line up with the test timestamps by position. A Series
# passed here would be aligned on index labels instead, and any mismatch with
# test_data's index (e.g. a datetime- or 0-based-indexed forecast) would
# silently fill the column with NaN.
future_predictions = pd.DataFrame({
    "Datetime": test_data['Datetime'],
    "SARIMA_Predicted": np.asarray(sarima_predictions).ravel(),
    "LSTM_Predicted": np.asarray(lstm_predictions).ravel(),
})
# Score both forecasts against the held-out actual sessions.
# RMSE is computed as sqrt(MSE) rather than mean_squared_error(..., squared=False):
# the `squared` keyword was deprecated in scikit-learn 1.4 and removed in 1.6,
# where passing it raises TypeError. np.sqrt works on every version.
mae_sarima_future = mean_absolute_error(test_data['Sessions'], sarima_predictions)
rmse_sarima_future = float(np.sqrt(mean_squared_error(test_data['Sessions'], sarima_predictions)))
mae_lstm_future = mean_absolute_error(test_data['Sessions'], lstm_predictions)
rmse_lstm_future = float(np.sqrt(mean_squared_error(test_data['Sessions'], lstm_predictions)))
def plot_predictions():
    """Plot actual vs. SARIMA/LSTM predicted traffic and save it as a PNG.

    Reads the module-level ``test_data`` (actual sessions) and
    ``future_predictions`` (model forecasts) frames.

    Returns:
        str: Path of the saved PNG, suitable as the value of a ``gr.Image``.
    """
    fig, ax = plt.subplots(figsize=(15, 6))
    try:
        # Actual traffic over the test window, taken from the same frame for
        # both axes. (Slicing webtraffic_data with iloc[-future_periods:] would
        # select the whole dataset if the test split were empty, since -0 == 0.)
        ax.plot(test_data['Datetime'], test_data['Sessions'].to_numpy(),
                label='Actual Traffic', color='black', linestyle='dotted', linewidth=2)
        ax.plot(future_predictions['Datetime'], future_predictions['SARIMA_Predicted'],
                label='SARIMA Predicted', color='blue', linewidth=2)
        ax.plot(future_predictions['Datetime'], future_predictions['LSTM_Predicted'],
                label='LSTM Predicted', color='green', linewidth=2)

        ax.set_title("Future Traffic Predictions: SARIMA vs LSTM", fontsize=16)
        ax.set_xlabel("Datetime", fontsize=12)
        ax.set_ylabel("Sessions", fontsize=12)
        ax.legend(loc="upper left")
        ax.grid(True)
        fig.tight_layout()

        # Save into the system temp directory. A hard-coded host path such as
        # Colab's "/content" does not exist elsewhere (e.g. on a Space), which
        # made savefig fail with FileNotFoundError.
        plot_path = os.path.join(tempfile.gettempdir(), "predictions_plot.png")
        fig.savefig(plot_path)
    finally:
        # Release the figure even if plotting or saving fails, so repeated
        # dashboard refreshes do not accumulate open figures.
        plt.close(fig)
    return plot_path
def display_metrics():
    """Tabulate the error metrics of both forecasting models.

    Returns:
        pandas.DataFrame: one row per model ("SARIMA", then "LSTM") with the
        columns "Model", "Mean Absolute Error (MAE)" and
        "Root Mean Squared Error (RMSE)".
    """
    rows = [
        ("SARIMA", mae_sarima_future, rmse_sarima_future),
        ("LSTM", mae_lstm_future, rmse_lstm_future),
    ]
    return pd.DataFrame(
        rows,
        columns=["Model", "Mean Absolute Error (MAE)", "Root Mean Squared Error (RMSE)"],
    )
def gradio_dashboard():
    """Build both dashboard outputs for the Gradio button handler.

    Returns:
        tuple: (path of the freshly saved prediction plot, metrics table
        rendered as plain text via ``DataFrame.to_string``).
    """
    return plot_predictions(), display_metrics().to_string()
# Gradio UI: a header, the prediction plot, the metrics table, and a button
# that (re)generates both via gradio_dashboard.
with gr.Blocks() as dashboard:
    gr.Markdown("## Web Traffic Prediction Dashboard")
    gr.Markdown("This dashboard compares predictions from SARIMA and LSTM models.")

    plot_output = gr.Image(label="Prediction Plot")
    metrics_output = gr.Textbox(label="Prediction Metrics", lines=15)

    update_button = gr.Button("Update Dashboard")
    update_button.click(fn=gradio_dashboard, outputs=[plot_output, metrics_output])

# Start the Gradio server for the dashboard.
dashboard.launch()