# NTI_ML_Project / app.py
# Author: Zeyadd-Mostaffa
# Last commit: 3e47c80 "Update app.py" (verified)
# File size: 4.5 kB
import gradio as gr
import xgboost as xgb
import numpy as np
import pandas as pd
import joblib
import os
import warnings
import shap
import matplotlib.pyplot as plt
# Quiet noisy XGBoost UserWarnings. Note the filter is not restricted to the
# xgboost module: any UserWarning whose message contains "WARNING" (matched
# case-insensitively by the warnings machinery) is ignored process-wide.
warnings.filterwarnings(
    action="ignore",
    message=r".*WARNING.*",
    category=UserWarning,
)
# Load the trained model. Only XGBoost's native format is read here.
def load_model(model_path="best_model.json"):
    """Load an XGBoost ``Booster`` from *model_path*.

    Args:
        model_path: Path of the saved model file. The default,
            ``"best_model.json"``, is resolved against the current working
            directory.

    Returns:
        The loaded ``xgb.Booster``, or ``None`` when the file does not exist.
        The rest of the app treats ``None`` as "no model available".
    """
    if not os.path.exists(model_path):
        print("❌ Model file not found.")
        return None

    booster = xgb.Booster()
    booster.load_model(model_path)
    print("✅ Model loaded successfully.")
    return booster


# Loaded once at import time and shared by the prediction/explanation code.
model = load_model()
# Prediction function with dynamic threshold
def predict_employee_status(satisfaction_level, last_evaluation, number_project,
                            average_monthly_hours, time_spent_company,
                            work_accident, promotion_last_5years, salary, department,
                            threshold=0.5):
    """Score one employee with the loaded model and explain the result.

    The arguments arrive positionally from the Gradio form, in signature order.

    Args:
        satisfaction_level, last_evaluation, number_project,
        average_monthly_hours, time_spent_company: numeric features, passed
            to the model unchanged.
        work_accident, promotion_last_5years: 0/1 flags.
        salary: ordinal salary code (the UI offers 0 = low, 1 = medium,
            2 = high).
        department: department name, one-hot encoded below.
        threshold: probability at or above which the employee is reported
            as likely to quit.

    Returns:
        A display string: the verdict and its probability followed by the
        explanation status; or a message starting with "❌" when the model is
        missing, a required input is empty, or scoring fails.
    """
    if model is None:
        return "❌ No model found. Please upload the model file."

    numeric_inputs = {
        "satisfaction_level": satisfaction_level,
        "last_evaluation": last_evaluation,
        "number_project": number_project,
        "average_monthly_hours": average_monthly_hours,
        "time_spent_company": time_spent_company,
        "Work_accident": work_accident,
        "promotion_last_5years": promotion_last_5years,
        "salary": salary,
    }
    # Gradio passes None for a field left empty; name it instead of letting
    # an object-dtype column fail deep inside xgboost.
    missing = [name for name, value in numeric_inputs.items() if value is None]
    if missing:
        return f"❌ Error: missing value for {', '.join(missing)}"

    # One-hot encode the department. An empty or unlisted department leaves
    # every flag at 0 -- presumably the reference category dropped at training
    # time (NOTE(review): confirm against the training pipeline).
    departments = [
        'RandD', 'accounting', 'hr', 'management', 'marketing',
        'product_mng', 'sales', 'support', 'technical'
    ]
    department_features = {
        f"department_{dept}": int(dept == department) for dept in departments
    }

    # All 17 features (8 numeric + 9 department flags) as a one-row DataFrame.
    # Column names and order (including the capitalised "Work_accident") are
    # kept exactly as before; presumably they must match the training columns.
    input_df = pd.DataFrame({
        **{name: [value] for name, value in numeric_inputs.items()},
        **{name: [flag] for name, flag in department_features.items()},
    })

    try:
        prediction_prob = float(model.predict(xgb.DMatrix(input_df))[0])
    except Exception as e:
        return f"❌ Error: {str(e)}"

    # Apply the dynamic threshold
    result = "✅ Employee is likely to quit." if prediction_prob >= threshold else "✅ Employee is likely to stay."

    # The explanation is best-effort: a SHAP or plotting failure must not
    # discard a prediction that already succeeded.
    try:
        explanation = explain_prediction(input_df)
    except Exception as e:
        explanation = f"❌ Explanation unavailable: {e}"

    return f"{result} (Probability: {prediction_prob:.2%})\n\nExplanation:\n{explanation}"
# SHAP Explainability
def explain_prediction(input_df, output_path="shap_explanation.png"):
    """Save a SHAP waterfall plot explaining the model's score for row 0 of *input_df*.

    Args:
        input_df: Feature frame built by the prediction code. Only its first
            row is plotted.
        output_path: Destination PNG. Defaults to ``"shap_explanation.png"``
            relative to the working directory, so every call overwrites the
            same file.

    Returns:
        A fixed status message for the text output. The image itself is only
        written to disk, not returned.
    """
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(input_df)
    explanation = shap.Explanation(
        values=shap_values[0],
        base_values=explainer.expected_value,
        data=input_df.iloc[0].values,
        feature_names=list(input_df.columns),
    )

    # shap.initjs() is intentionally not called. It loads the JavaScript used
    # by notebook (IPython) visualisations, which a matplotlib waterfall saved
    # to a file does not need.
    fig = plt.figure()
    try:
        # show=False: save the figure explicitly instead of having shap call
        # plt.show() first.
        shap.waterfall_plot(explanation, show=False)
        # "tight" keeps the long feature labels inside the saved image.
        plt.savefig(output_path, bbox_inches="tight")
    finally:
        # pyplot keeps open figures alive; close this one so repeated
        # requests do not accumulate figures.
        plt.close(fig)
    return "SHAP explanation generated for this prediction."
# Gradio interface with dynamic threshold
def gradio_interface():
    """Assemble the Gradio form around ``predict_employee_status`` and serve it.

    Gradio hands the component values to ``predict_employee_status``
    positionally. The components below must therefore appear in exactly the
    order of that function's parameters.
    """
    department_choices = [
        'RandD', 'accounting', 'hr', 'management', 'marketing',
        'product_mng', 'sales', 'support', 'technical'
    ]

    numeric_fields = [
        gr.Number(label="Satisfaction Level (0.0 - 1.0)"),
        gr.Number(label="Last Evaluation (0.0 - 1.0)"),
        gr.Number(label="Number of Projects (1 - 10)"),
        gr.Number(label="Average Monthly Hours (80 - 320)"),
        gr.Number(label="Time Spent at Company (Years)"),
    ]
    choice_fields = [
        gr.Radio([0, 1], label="Work Accident (0 = No, 1 = Yes)"),
        gr.Radio([0, 1], label="Promotion in Last 5 Years (0 = No, 1 = Yes)"),
        gr.Radio([0, 1, 2], label="Salary (0 = Low, 1 = Medium, 2 = High)"),
        gr.Dropdown(department_choices, label="Department"),
    ]
    threshold_field = gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Prediction Threshold")

    app = gr.Interface(
        fn=predict_employee_status,
        inputs=numeric_fields + choice_fields + [threshold_field],
        outputs="text",
        title="Employee Retention Prediction System (With SHAP Explainability)",
        description=(
            "Predict whether an employee is likely to stay or quit based on their profile. "
            "Adjust the threshold for accurate predictions."
        ),
        # NOTE(review): "dark" may not be a built-in theme name in every
        # Gradio release -- confirm the page renders with the intended theme.
        theme="dark",
    )
    app.launch()


# Build and launch the UI as soon as the module runs.
gradio_interface()