Zeyadd-Mostaffa commited on
Commit
7c5d1d0
Β·
verified Β·
1 Parent(s): 1372365

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -73
app.py CHANGED
@@ -1,73 +1,93 @@
1
- import gradio as gr
2
- import xgboost as xgb
3
- import numpy as np
4
- import joblib
5
- import os
6
- import warnings
7
-
8
- # Suppress XGBoost warnings
9
- warnings.filterwarnings("ignore", category=UserWarning, message=".*WARNING.*")
10
-
11
-
12
- # Load your model (automatically detect XGBoost or joblib model)
13
- def load_model():
14
- if os.path.exists("models/best_model.json"):
15
- model = xgb.Booster()
16
- model.load_model("models/best_model.json")
17
- print("βœ… Model loaded using XGBoost's native method.")
18
- return model
19
- elif os.path.exists("models/best_model.pkl"):
20
- model = joblib.load("models/best_model.pkl")
21
- print("βœ… Model loaded using Joblib.")
22
- return model
23
- else:
24
- print("❌ No model file found.")
25
- return None
26
-
27
-
28
- model = load_model()
29
-
30
-
31
- # Prediction function
32
- def predict_employee_status(satisfaction_level, last_evaluation, number_project,
33
- average_monthly_hours, time_spend_company,
34
- work_accident, promotion_last_5years, salary):
35
- input_data = np.array([[satisfaction_level, last_evaluation, number_project,
36
- average_monthly_hours, time_spend_company,
37
- work_accident, promotion_last_5years, salary]])
38
-
39
- if model is None:
40
- return "❌ No model found. Please upload the model file."
41
-
42
- if isinstance(model, xgb.Booster):
43
- dmatrix = xgb.DMatrix(input_data)
44
- prediction = model.predict(dmatrix)[0]
45
- result = "βœ… The employee is likely to Quit." if prediction > 0.5 else "βœ… The employee is likely to Stay."
46
- else:
47
- prediction = model.predict(input_data)[0]
48
- result = "βœ… The employee is likely to Quit." if prediction == 1 else "βœ… The employee is likely to Stay."
49
-
50
- return result
51
-
52
-
53
- # Gradio interface with enhanced UI
54
- interface = gr.Interface(
55
- fn=predict_employee_status,
56
- inputs=[
57
- gr.Number(label="Satisfaction Level (0.0 - 1.0)", value=0.5),
58
- gr.Number(label="Last Evaluation (0.0 - 1.0)", value=0.6),
59
- gr.Number(label="Number of Projects (1 - 10)", value=3),
60
- gr.Number(label="Average Monthly Hours (80 - 320)", value=150),
61
- gr.Number(label="Time Spent at Company (Years)", value=3),
62
- gr.Radio([0, 1], label="Work Accident (0 = No, 1 = Yes)"),
63
- gr.Radio([0, 1], label="Promotion in Last 5 Years (0 = No, 1 = Yes)"),
64
- gr.Dropdown(choices=[0, 1, 2], label="Salary Level (0 = Low, 1 = Medium, 2 = High)")
65
- ],
66
- outputs="text",
67
- title="πŸš€ Employee Retention Prediction System",
68
- description="Predict whether an employee is likely to stay or quit based on their profile.",
69
- live=False
70
- )
71
-
72
- # Launch Gradio app
73
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import xgboost as xgb
3
+ import numpy as np
4
+ import joblib
5
+ import os
6
+ import warnings
7
+
8
+ # Suppress XGBoost warnings
9
+ warnings.filterwarnings("ignore", category=UserWarning, message=".*WARNING.*")
10
+
11
+ # Load your model (automatically detect XGBoost or joblib model)
12
+ def load_model():
13
+ if os.path.exists("best_model.json"): # Model in root directory
14
+ model = xgb.Booster()
15
+ model.load_model("best_model.json")
16
+ print("βœ… Model loaded using XGBoost's native method.")
17
+ return model
18
+ elif os.path.exists("best_model.pkl"): # Joblib model in root directory
19
+ model = joblib.load("best_model.pkl")
20
+ print("βœ… Model loaded using Joblib.")
21
+ return model
22
+ else:
23
+ print("❌ No model file found.")
24
+ return None
25
+
26
+ model = load_model()
27
+
28
+ # Prediction function
29
+ def predict_employee_status(satisfaction_level, last_evaluation, number_project,
30
+ average_monthly_hours, time_spend_company,
31
+ work_accident, promotion_last_5years, salary, department):
32
+
33
+ # Encode the department as numeric (One-Hot Encoding or Label Encoding)
34
+ department_mapping = {
35
+ "Sales": 0,
36
+ "Technical": 1,
37
+ "Support": 2,
38
+ "IT": 3,
39
+ "Management": 4,
40
+ "Product Management": 5,
41
+ "Marketing": 6,
42
+ "HR": 7,
43
+ "Accounting": 8,
44
+ "R&D": 9
45
+ }
46
+
47
+ # Convert the department to a numeric value
48
+ department_encoded = department_mapping.get(department, 0)
49
+
50
+ # Prepare input data including the department
51
+ input_data = np.array([[satisfaction_level, last_evaluation, number_project,
52
+ average_monthly_hours, time_spend_company,
53
+ work_accident, promotion_last_5years, salary, department_encoded]])
54
+
55
+ if model is None:
56
+ return "❌ No model found. Please upload the model file."
57
+
58
+ if isinstance(model, xgb.Booster):
59
+ dmatrix = xgb.DMatrix(input_data)
60
+ prediction = model.predict(dmatrix)[0]
61
+ result = "βœ… The employee is likely to Quit." if prediction > 0.5 else "βœ… The employee is likely to Stay."
62
+ else:
63
+ prediction = model.predict(input_data)[0]
64
+ result = "βœ… The employee is likely to Quit." if prediction == 1 else "βœ… The employee is likely to Stay."
65
+
66
+ return result
67
+
68
+ # Gradio interface with enhanced UI including Department
69
+ interface = gr.Interface(
70
+ fn=predict_employee_status,
71
+ inputs=[
72
+ gr.Number(label="Satisfaction Level (0.0 - 1.0)", value=0.5),
73
+ gr.Number(label="Last Evaluation (0.0 - 1.0)", value=0.6),
74
+ gr.Number(label="Number of Projects (1 - 10)", value=3),
75
+ gr.Number(label="Average Monthly Hours (80 - 320)", value=150),
76
+ gr.Number(label="Time Spent at Company (Years)", value=3),
77
+ gr.Radio([0, 1], label="Work Accident (0 = No, 1 = Yes)"),
78
+ gr.Radio([0, 1], label="Promotion in Last 5 Years (0 = No, 1 = Yes)"),
79
+ gr.Dropdown(choices=[0, 1, 2], label="Salary Level (0 = Low, 1 = Medium, 2 = High)"),
80
+ gr.Dropdown(
81
+ choices=["Sales", "Technical", "Support", "IT", "Management",
82
+ "Product Management", "Marketing", "HR", "Accounting", "R&D"],
83
+ label="Department"
84
+ )
85
+ ],
86
+ outputs="text",
87
+ title="πŸš€ Employee Retention Prediction System",
88
+ description="Predict whether an employee is likely to stay or quit based on their profile.",
89
+ live=False
90
+ )
91
+
92
+ # Launch Gradio app
93
+ interface.launch()