# NTI_ML_Project / app.py
# Author: Zeyadd-Mostaffa
# Last commit: "Update app.py" (7c5d1d0, verified)
# Size: 3.52 kB
import gradio as gr
import xgboost as xgb
import numpy as np
import joblib
import os
import warnings
# Suppress XGBoost warnings: ignore every UserWarning whose text contains
# "WARNING". The ``message`` pattern is matched case-insensitively against the
# start of the warning text, so ".*WARNING.*" effectively means "mentions
# 'warning' anywhere". That is broader than XGBoost alone; other warning
# categories are unaffected.
warnings.filterwarnings(
    action="ignore",
    message=".*WARNING.*",
    category=UserWarning,
)
# Load the model (automatically detect an XGBoost JSON file or a joblib pickle)
def load_model(json_path="best_model.json", pkl_path="best_model.pkl"):
    """Load the trained model from disk, preferring XGBoost's native format.

    ``json_path`` is checked first and, if it exists, loaded into an
    ``xgb.Booster``. Otherwise ``pkl_path`` is loaded with ``joblib``. Relative
    paths (including the defaults) are resolved against the current working
    directory.

    Args:
        json_path: Path to a model file readable by ``xgb.Booster.load_model``.
        pkl_path: Path to a model file readable by ``joblib.load``.

    Returns:
        The loaded model (an ``xgb.Booster``, or whatever object the joblib
        file contains), or ``None`` when neither file exists.
    """
    if os.path.exists(json_path):
        booster = xgb.Booster()
        booster.load_model(json_path)
        print("✅ Model loaded using XGBoost's native method.")
        return booster

    if os.path.exists(pkl_path):
        # joblib.load unpickles the file, which can execute arbitrary code:
        # only point this at model files you trust (e.g. ones shipped with the app).
        estimator = joblib.load(pkl_path)
        print("✅ Model loaded using Joblib.")
        return estimator

    print("❌ No model file found.")
    return None


model = load_model()
# Label encoding for the "Department" input.
# NOTE(review): these codes are only meaningful if they match the encoding the
# model was trained with; the training code is not visible here, so confirm.
DEPARTMENT_MAPPING = {
    "Sales": 0,
    "Technical": 1,
    "Support": 2,
    "IT": 3,
    "Management": 4,
    "Product Management": 5,
    "Marketing": 6,
    "HR": 7,
    "Accounting": 8,
    "R&D": 9,
}


# Prediction function
def predict_employee_status(satisfaction_level, last_evaluation, number_project,
                            average_monthly_hours, time_spend_company,
                            work_accident, promotion_last_5years, salary, department):
    """Predict whether an employee is likely to quit or stay.

    Builds a single-row feature vector from the form inputs (in parameter
    order, with ``department`` label-encoded via ``DEPARTMENT_MAPPING``) and
    scores it with the module-level ``model``.

    Args:
        satisfaction_level, last_evaluation, number_project,
        average_monthly_hours, time_spend_company, work_accident,
        promotion_last_5years, salary: Numeric feature values from the form.
        department: One of the keys of ``DEPARTMENT_MAPPING``.

    Returns:
        A human-readable message: either the prediction, or the reason no
        prediction was made (no model loaded, a missing input, or an unknown
        department).
    """
    if model is None:
        return "❌ No model found. Please upload the model file."

    # An empty form field arrives as None. numpy would turn that into an
    # object-dtype array, which the model cannot score, so reject it up front.
    labelled_inputs = (
        ("Satisfaction Level", satisfaction_level),
        ("Last Evaluation", last_evaluation),
        ("Number of Projects", number_project),
        ("Average Monthly Hours", average_monthly_hours),
        ("Time Spent at Company", time_spend_company),
        ("Work Accident", work_accident),
        ("Promotion in Last 5 Years", promotion_last_5years),
        ("Salary Level", salary),
        ("Department", department),
    )
    missing = [label for label, value in labelled_inputs if value is None]
    if missing:
        return "❌ Please provide a value for: " + ", ".join(missing) + "."

    # Reject unknown departments instead of silently encoding them as 0 ("Sales").
    if department not in DEPARTMENT_MAPPING:
        return f"❌ Unknown department: {department!r}."
    department_encoded = DEPARTMENT_MAPPING[department]

    # NOTE(review): the column order must match the one used at training time
    # (department last, after salary); confirm against the training code.
    input_data = np.array([[satisfaction_level, last_evaluation, number_project,
                            average_monthly_hours, time_spend_company,
                            work_accident, promotion_last_5years, salary,
                            department_encoded]])

    if isinstance(model, xgb.Booster):
        # A raw Booster returns a score; > 0.5 is treated as "Quit".
        # NOTE(review): assumes a probability-style (e.g. binary:logistic) output.
        prediction = model.predict(xgb.DMatrix(input_data))[0]
        quits = prediction > 0.5
    else:
        # Estimator API: predict() presumably returns class labels; 1 means "Quit".
        prediction = model.predict(input_data)[0]
        quits = prediction == 1

    if quits:
        return "✅ The employee is likely to Quit."
    return "✅ The employee is likely to Stay."
# Gradio interface with enhanced UI including Department.
# Inputs are listed in the same order as predict_employee_status's parameters.
# Every component has a default value: a Radio/Dropdown with nothing selected
# would otherwise reach the prediction function as None and break the model
# call. Choice inputs default to their first option.
interface = gr.Interface(
    fn=predict_employee_status,
    inputs=[
        gr.Number(label="Satisfaction Level (0.0 - 1.0)", value=0.5),
        gr.Number(label="Last Evaluation (0.0 - 1.0)", value=0.6),
        gr.Number(label="Number of Projects (1 - 10)", value=3),
        gr.Number(label="Average Monthly Hours (80 - 320)", value=150),
        gr.Number(label="Time Spent at Company (Years)", value=3),
        gr.Radio([0, 1], label="Work Accident (0 = No, 1 = Yes)", value=0),
        gr.Radio([0, 1], label="Promotion in Last 5 Years (0 = No, 1 = Yes)", value=0),
        gr.Dropdown(
            choices=[0, 1, 2],
            label="Salary Level (0 = Low, 1 = Medium, 2 = High)",
            value=0,
        ),
        gr.Dropdown(
            choices=["Sales", "Technical", "Support", "IT", "Management",
                     "Product Management", "Marketing", "HR", "Accounting", "R&D"],
            label="Department",
            value="Sales",
        ),
    ],
    outputs="text",
    title="🚀 Employee Retention Prediction System",
    description="Predict whether an employee is likely to stay or quit based on their profile.",
    live=False,
)
# Launch Gradio app
interface.launch()