import gradio as gr
import xgboost as xgb
import numpy as np
import pandas as pd
import joblib
import os
import warnings
import shap
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, precision_recall_curve
from imblearn.over_sampling import SMOTE

# Suppress XGBoost warnings
warnings.filterwarnings("ignore", category=UserWarning, message=".*WARNING.*")

# Saved XGBoost model in XGBoost's native format.
MODEL_PATH = "xgboost_model.json"  # Ensure this matches your file name

# Departments used for one-hot encoding. Shared by the feature builder and the
# Gradio dropdown so the two lists cannot drift apart.
DEPARTMENTS = [
    "RandD",
    "accounting",
    "hr",
    "management",
    "marketing",
    "product_mng",
    "sales",
    "support",
    "technical",
]

# Where the SHAP waterfall plot for the most recent prediction is written.
SHAP_IMAGE_PATH = "shap_explanation.png"


def load_model(model_path=MODEL_PATH):
    """Load the XGBoost booster stored at *model_path*.

    Only XGBoost's native model format (``Booster.load_model``) is handled;
    joblib/pickle files are not detected.

    Returns:
        The loaded ``xgb.Booster``, or ``None`` if the file does not exist.
    """
    if not os.path.exists(model_path):
        print("❌ Model file not found.")
        return None
    booster = xgb.Booster()
    booster.load_model(model_path)
    print("✅ Model loaded successfully.")
    return booster


model = load_model()


def _build_features(satisfaction_level, last_evaluation, number_project,
                    average_monthly_hours, time_spent_company, work_accident,
                    promotion_last_5years, salary, department):
    """Return a one-row DataFrame holding raw, engineered and one-hot features."""
    # One-hot encode the department; an unknown/empty department leaves all
    # flags at 0.
    department_features = {f"department_{dept}": 0 for dept in DEPARTMENTS}
    if department in DEPARTMENTS:
        department_features[f"department_{department}"] = 1

    input_data = {
        "satisfaction_level": [satisfaction_level],
        "last_evaluation": [last_evaluation],
        "number_project": [number_project],
        "average_monthly_hours": [average_monthly_hours],
        "time_spent_company": [time_spent_company],
        "Work_accident": [work_accident],
        "promotion_last_5years": [promotion_last_5years],
        "salary": [salary],
        # Interaction features computed from the raw inputs.
        "satisfaction_evaluation": [satisfaction_level * last_evaluation],
        "work_balance": [average_monthly_hours / number_project],
        **department_features,
    }
    return pd.DataFrame(input_data)


# Prediction function with dynamic threshold and balanced data
def predict_employee_status(satisfaction_level, last_evaluation, number_project,
                            average_monthly_hours, time_spent_company,
                            work_accident, promotion_last_5years, salary,
                            department, threshold=0.5):
    """Predict whether an employee is likely to quit and explain the result.

    Args:
        satisfaction_level, last_evaluation, number_project,
        average_monthly_hours, time_spent_company: numeric profile values.
        work_accident, promotion_last_5years: 0 = No, 1 = Yes.
        salary: 0 = Low, 1 = Medium, 2 = High.
        department: one of ``DEPARTMENTS``; anything else encodes as all-zero.
        threshold: probability at or above which the employee is labelled as
            likely to quit.

    Returns:
        A human-readable message. Failures are reported as "❌ ..." strings
        rather than raised, so the Gradio UI always shows something.
    """
    if model is None:
        return "❌ No model found. Please upload the model file."

    # Gradio passes None for fields the user left empty; catch that (and the
    # zero divisor of work_balance) before doing any arithmetic.
    provided = {
        "satisfaction_level": satisfaction_level,
        "last_evaluation": last_evaluation,
        "number_project": number_project,
        "average_monthly_hours": average_monthly_hours,
        "time_spent_company": time_spent_company,
        "work_accident": work_accident,
        "promotion_last_5years": promotion_last_5years,
        "salary": salary,
    }
    missing = [name for name, value in provided.items() if value is None]
    if missing:
        return f"❌ Error: missing value for {', '.join(missing)}."
    if number_project <= 0:
        return "❌ Error: Number of projects must be greater than 0."

    try:
        input_df = _build_features(
            satisfaction_level, last_evaluation, number_project,
            average_monthly_hours, time_spent_company, work_accident,
            promotion_last_5years, salary, department,
        )
        # XGBoost rejects a DMatrix whose column order differs from the one the
        # booster was trained with; align to the stored order when available.
        if model.feature_names:
            input_df = input_df[model.feature_names]

        prediction_prob = float(model.predict(xgb.DMatrix(input_df))[0])

        # Apply the dynamic threshold
        result = ("✅ Employee is likely to quit." if prediction_prob >= threshold
                  else "✅ Employee is likely to stay.")
        explanation = explain_prediction(input_df)
        return (f"{result} (Probability: {prediction_prob:.2%})"
                f"\n\nExplanation:\n{explanation}")
    except Exception as e:
        return f"❌ Error: {str(e)}"


# SHAP Explainability (Directly Integrated)
def explain_prediction(input_df, image_path=SHAP_IMAGE_PATH):
    """Save a SHAP waterfall plot for the first row of *input_df*.

    The plot is written to *image_path* on disk; it is not returned to the
    Gradio UI (the interface output is text only).

    Returns:
        A status message. Errors are reported in the message, not raised.
    """
    fig = None
    try:
        explainer = shap.TreeExplainer(model)
        shap_values = explainer.shap_values(input_df)
        explanation = shap.Explanation(
            values=shap_values[0],
            base_values=explainer.expected_value,
            data=input_df.iloc[0].values,
            feature_names=list(input_df.columns),
        )
        fig = plt.figure()
        # show=False: with the default show=True the plot is shown/cleared
        # before savefig runs, which leaves an empty image on disk.
        shap.waterfall_plot(explanation, show=False)
        fig.savefig(image_path, bbox_inches="tight")
        return "SHAP explanation generated for this prediction."
    except Exception as e:
        return f"❌ Error in SHAP: {str(e)}"
    finally:
        # Close the figure so repeated requests don't accumulate open figures.
        if fig is not None:
            plt.close(fig)


# Gradio interface with dynamic threshold and SHAP
def gradio_interface():
    """Build the Gradio UI around ``predict_employee_status`` and launch it."""
    interface = gr.Interface(
        fn=predict_employee_status,
        inputs=[
            gr.Number(label="Satisfaction Level (0.0 - 1.0)"),
            gr.Number(label="Last Evaluation (0.0 - 1.0)"),
            gr.Number(label="Number of Projects (1 - 10)"),
            gr.Number(label="Average Monthly Hours (80 - 320)"),
            gr.Number(label="Time Spent at Company (Years)"),
            gr.Radio([0, 1], label="Work Accident (0 = No, 1 = Yes)"),
            gr.Radio([0, 1], label="Promotion in Last 5 Years (0 = No, 1 = Yes)"),
            gr.Radio([0, 1, 2], label="Salary (0 = Low, 1 = Medium, 2 = High)"),
            gr.Dropdown(list(DEPARTMENTS), label="Department"),
            gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Prediction Threshold"),
        ],
        outputs="text",
        title="Employee Retention Prediction System (With SHAP & ROC Threshold)",
        description="Predict whether an employee is likely to stay or quit based on their profile. Adjust the threshold for accurate predictions.",
        theme="dark",
    )
    interface.launch()


# NOTE: launches the app at import time (kept as in the original script).
gradio_interface()