Spaces:
Sleeping
Sleeping
import gradio as gr | |
import matplotlib.pyplot as plt | |
import pandas as pd | |
import joblib | |
# Load the web-traffic dataset and guarantee a datetime-typed, sorted
# "Datetime" column for the plotting code below.
data_file = "webtraffic.csv"


def _load_webtraffic(path):
    """Read the traffic CSV at *path* and normalise its ``Datetime`` column.

    If the CSV already has a ``Datetime`` column it is parsed with
    ``pd.to_datetime``. Otherwise timestamps are synthesised from the
    ``Hour Index`` column as hour offsets from 2024-01-01 00:00:00.
    The returned frame is sorted by ``Datetime`` (original index labels
    are kept).

    Raises:
        KeyError: if the CSV has neither a ``Datetime`` nor an
            ``Hour Index`` column, so no time axis can be built.
    """
    frame = pd.read_csv(path)
    if "Datetime" in frame.columns:
        frame["Datetime"] = pd.to_datetime(frame["Datetime"])
    elif "Hour Index" in frame.columns:
        print("Datetime column missing. Attempting to create from 'Hour Index'.")
        # NOTE(review): the 2024-01-01 origin is arbitrary; confirm it matches
        # the period the SARIMA model was trained on.
        start_date = pd.Timestamp("2024-01-01 00:00:00")
        frame["Datetime"] = start_date + pd.to_timedelta(
            frame["Hour Index"], unit="h"
        )
    else:
        # Fail with an actionable message instead of a bare KeyError on
        # 'Hour Index' after claiming we are "attempting to create" it.
        raise KeyError(
            f"{path!r} has neither a 'Datetime' nor an 'Hour Index' column; "
            "cannot build a time axis."
        )
    # Later code takes .iloc[-1] as the latest timestamp, so order matters.
    return frame.sort_values("Datetime")


webtraffic_data = _load_webtraffic(data_file)
# Deserialise the fitted SARIMA model stored next to the app.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code; only ever point it at a trusted, locally produced artefact.
sarima_model = joblib.load("sarima_model.pkl")

# Number of future steps the dashboard forecasts.
future_periods = 48

# Placeholder error metrics surfaced in the dashboard.
# NOTE(review): these are hard-coded constants, not computed from any
# forecast; replace with real evaluation results before presenting them as
# model performance.
mae_sarima_future, rmse_sarima_future = 100, 150
# Function to generate plot based on SARIMA model
def generate_plot(plot_path="sarima_prediction_plot.png"):
    """Forecast ``future_periods`` steps and plot them against actual traffic.

    Calls the module-level ``sarima_model.predict(n_periods=future_periods)``,
    stamps the predictions hourly starting one hour after the last
    ``Datetime`` in ``webtraffic_data``, and draws them alongside the
    historical ``Sessions`` series.

    Args:
        plot_path: Where to write the PNG. Defaults to the original fixed
            filename in the current working directory.

    Returns:
        The path the PNG was written to (``plot_path``).
    """
    # Build the figure with the object-oriented API rather than pyplot:
    # pyplot keeps a global "current figure" that is shared between the
    # threads Gradio uses to serve concurrent button clicks.
    from matplotlib.figure import Figure

    # Generate periods+1 stamps starting at the last observation and drop the
    # first, so the forecast starts one hour after the data ends.
    # "h" replaces the "H" alias deprecated in pandas 2.2.
    future_dates = pd.date_range(
        start=webtraffic_data["Datetime"].iloc[-1],
        periods=future_periods + 1,
        freq="h",
    )[1:]
    # NOTE(review): assumes the model's forecast origin is the end of
    # webtraffic_data; confirm it was fitted on this full series.
    sarima_predictions = sarima_model.predict(n_periods=future_periods)
    future_predictions = pd.DataFrame(
        {"Datetime": future_dates, "SARIMA_Predicted": sarima_predictions}
    )

    fig = Figure(figsize=(15, 6))
    ax = fig.subplots()
    ax.plot(
        webtraffic_data["Datetime"],
        webtraffic_data["Sessions"],
        label="Actual Traffic",
        color="black",
        linestyle="dotted",
        linewidth=2,
    )
    ax.plot(
        future_predictions["Datetime"],
        future_predictions["SARIMA_Predicted"],
        label="SARIMA Predicted",
        color="blue",
        linewidth=2,
    )
    ax.set_title("SARIMA Predictions vs Actual Traffic", fontsize=16)
    ax.set_xlabel("Datetime", fontsize=12)
    ax.set_ylabel("Sessions", fontsize=12)
    ax.legend(loc="upper left")
    ax.grid(True)
    fig.tight_layout()
    # A pyplot-free Figure needs no plt.close(); it is freed once unreferenced.
    fig.savefig(plot_path)
    return plot_path
# Function to display SARIMA metrics
def display_metrics():
    """Return a one-row DataFrame of the SARIMA error metrics.

    The numbers are read from the module-level ``mae_sarima_future`` and
    ``rmse_sarima_future`` at call time.
    """
    record = {
        "Model": "SARIMA",
        "Mean Absolute Error (MAE)": mae_sarima_future,
        "Root Mean Squared Error (RMSE)": rmse_sarima_future,
    }
    # A single record becomes a single row; column order follows the keys.
    return pd.DataFrame([record])
# Gradio interface function
def dashboard_interface():
    """Produce the two outputs wired to the dashboard button.

    Returns:
        A ``(plot_path, metrics_text)`` pair: the PNG path from
        ``generate_plot()`` and ``display_metrics()`` rendered with
        ``DataFrame.to_string()``.
    """
    # Tuple elements evaluate left to right: the plot is built first.
    return generate_plot(), display_metrics().to_string()
# Assemble the Gradio UI: a title, a description, the plot and metrics output
# components, and a button that fills both outputs when clicked.
with gr.Blocks() as dashboard:
    gr.Markdown("## Interactive SARIMA Web Traffic Prediction Dashboard")
    gr.Markdown(
        "This dashboard shows SARIMA model predictions vs actual traffic along with performance metrics."
    )
    plot_output = gr.Image(label="Prediction Plot")
    metrics_output = gr.Textbox(label="Metrics", lines=15)
    # Created after the outputs so it renders below them.
    generate_button = gr.Button("Generate Predictions")
    generate_button.click(
        fn=dashboard_interface,
        inputs=[],
        outputs=[plot_output, metrics_output],
    )

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    dashboard.launch()