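# webtraffic / app.py
# Page residue from the hosting site (kept for provenance, commented out so
# the file parses as Python):
#   manjunathainti's picture
#   Update app.py
#   f8d0f44 verified
#   raw
#   history blame
#   3.11 kB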
import gradio as gr
import matplotlib.pyplot as plt
import pandas as pd
import joblib
# ---------------------------------------------------------------------------
# One-time setup executed at import: read the traffic history, guarantee a
# usable "Datetime" column, and load the pickled forecasting model.
# ---------------------------------------------------------------------------

# Historical traffic is read from a CSV located relative to the working dir.
data_file = "webtraffic.csv"
webtraffic_data = pd.read_csv(data_file)

if "Datetime" in webtraffic_data.columns:
    # The column exists: parse its contents into real timestamps.
    webtraffic_data["Datetime"] = pd.to_datetime(webtraffic_data["Datetime"])
else:
    # No timestamps in the file: synthesise them by treating "Hour Index" as
    # a number of hours elapsed since a fixed start date.
    print("Datetime column missing. Attempting to create from 'Hour Index'.")
    start_date = pd.Timestamp("2024-01-01 00:00:00")
    hour_offsets = pd.to_timedelta(webtraffic_data["Hour Index"], unit="h")
    webtraffic_data["Datetime"] = start_date + hour_offsets

# Keep rows in chronological order so the last row is the latest observation.
webtraffic_data.sort_values(by="Datetime", inplace=True)

# Pre-trained SARIMA model, deserialised from disk with joblib.
sarima_model = joblib.load("sarima_model.pkl")

# Number of future steps (hours) to forecast and plot.
future_periods = 48

# Placeholder metric values shown on the dashboard; they are hard-coded here,
# not computed from the model.
mae_sarima_future = 100
rmse_sarima_future = 150
# Function to generate plot based on SARIMA model
def generate_plot():
    """Plot historical traffic with the SARIMA forecast and save it as a PNG.

    Reads the module-level ``webtraffic_data``, ``sarima_model`` and
    ``future_periods``. The forecast spans ``future_periods`` hourly steps,
    the first one being one hour after the last row of ``webtraffic_data``.

    Returns:
        str: Path of the image that was written
        (``"sarima_prediction_plot.png"``; overwritten on every call).
    """
    # Use an explicit one-hour step instead of the "H" string alias, which is
    # deprecated in recent pandas releases, and start directly at the first
    # future timestamp rather than generating one extra point and slicing it.
    one_hour = pd.Timedelta(hours=1)
    last_observed = webtraffic_data["Datetime"].iloc[-1]
    future_dates = pd.date_range(
        start=last_observed + one_hour, periods=future_periods, freq=one_hour
    )

    # NOTE(review): the ``n_periods`` keyword suggests a pmdarima-style
    # estimator -- confirm against how sarima_model.pkl was produced.
    sarima_predictions = sarima_model.predict(n_periods=future_periods)
    future_predictions = pd.DataFrame(
        {"Datetime": future_dates, "SARIMA_Predicted": sarima_predictions}
    )

    plot_path = "sarima_prediction_plot.png"

    # Work on an explicit figure/axes pair and always close that exact figure,
    # so a failure while drawing or saving cannot leave open figures
    # accumulating in the long-running dashboard process.
    fig, ax = plt.subplots(figsize=(15, 6))
    try:
        ax.plot(
            webtraffic_data["Datetime"],
            webtraffic_data["Sessions"],
            label="Actual Traffic",
            color="black",
            linestyle="dotted",
            linewidth=2,
        )
        ax.plot(
            future_predictions["Datetime"],
            future_predictions["SARIMA_Predicted"],
            label="SARIMA Predicted",
            color="blue",
            linewidth=2,
        )
        ax.set_title("SARIMA Predictions vs Actual Traffic", fontsize=16)
        ax.set_xlabel("Datetime", fontsize=12)
        ax.set_ylabel("Sessions", fontsize=12)
        ax.legend(loc="upper left")
        ax.grid(True)
        fig.tight_layout()
        fig.savefig(plot_path)
    finally:
        plt.close(fig)

    return plot_path
# Function to display SARIMA metrics
def display_metrics():
    """Return the SARIMA error metrics as a one-row table.

    The values are taken from the module-level ``mae_sarima_future`` and
    ``rmse_sarima_future``.

    Returns:
        pandas.DataFrame: A single row with the columns ``"Model"``,
        ``"Mean Absolute Error (MAE)"`` and
        ``"Root Mean Squared Error (RMSE)"``.
    """
    labelled_values = (
        ("Model", "SARIMA"),
        ("Mean Absolute Error (MAE)", mae_sarima_future),
        ("Root Mean Squared Error (RMSE)", rmse_sarima_future),
    )
    # Each column holds exactly one value, hence the single-element lists.
    return pd.DataFrame({label: [value] for label, value in labelled_values})
# Gradio interface function
def dashboard_interface():
    """Produce both dashboard outputs for a single button click.

    Returns:
        tuple: ``(image_path, metrics_text)`` -- the path returned by
        ``generate_plot()`` and the table from ``display_metrics()`` rendered
        with ``DataFrame.to_string()``.
    """
    image_path = generate_plot()
    metrics_text = display_metrics().to_string()
    return image_path, metrics_text
# Assemble the dashboard layout and wire its single button to the callback.
with gr.Blocks() as dashboard:
    # Heading and short description shown at the top of the page.
    gr.Markdown("## Interactive SARIMA Web Traffic Prediction Dashboard")
    gr.Markdown(
        "This dashboard shows SARIMA model predictions vs actual traffic along with performance metrics."
    )

    # Output widgets filled in by dashboard_interface().
    plot_output = gr.Image(label="Prediction Plot")
    metrics_output = gr.Textbox(label="Metrics", lines=15)

    # The callback takes no inputs; its two return values map onto the two
    # output widgets in order.
    generate_button = gr.Button("Generate Predictions")
    generate_button.click(
        fn=dashboard_interface,
        inputs=[],
        outputs=[plot_output, metrics_output],
    )

# Start the web server only when the file is executed as a script.
if __name__ == "__main__":
    dashboard.launch()