Zeyadd-Mostaffa commited on
Commit
96a2bd1
Β·
verified Β·
1 Parent(s): 47a7c4d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -21
app.py CHANGED
@@ -5,8 +5,6 @@ import joblib
5
  import os
6
  import warnings
7
  from huggingface_hub import hf_hub_download
8
- import xgboost
9
-
10
 
11
  # Suppress warnings
12
  warnings.filterwarnings("ignore")
@@ -23,20 +21,27 @@ def load_model():
23
 
24
  model = load_model()
25
 
26
- # Prediction function
27
- def predict_employee_status(satisfaction_level, last_evaluation, number_project,
28
- average_monthly_hours, time_spend_company,
29
- work_accident, promotion_last_5years, salary, department, threshold=0.5):
30
-
31
- departments = ['RandD', 'accounting', 'hr', 'management', 'marketing',
32
- 'product_mng', 'sales', 'support', 'technical']
 
 
 
 
 
33
  department_features = {f"department_{dept}": 0 for dept in departments}
34
  if department in departments:
35
  department_features[f"department_{department}"] = 1
36
 
 
37
  satisfaction_evaluation = satisfaction_level * last_evaluation
38
  work_balance = average_monthly_hours / number_project
39
 
 
40
  input_data = {
41
  "satisfaction_level": [satisfaction_level],
42
  "last_evaluation": [last_evaluation],
@@ -52,13 +57,17 @@ def predict_employee_status(satisfaction_level, last_evaluation, number_project,
52
  }
53
 
54
  input_df = pd.DataFrame(input_data)
55
- prediction_prob = model.predict_proba(input_df)[0][1]
56
- result = "βœ… Employee is likely to quit." if prediction_prob >= threshold else "βœ… Employee is likely to stay."
57
- return f"{result} (Probability: {prediction_prob:.2%})"
58
 
59
- # Launch Gradio UI
 
 
 
 
 
 
 
60
  def gradio_interface():
61
- gr.Interface(
62
  fn=predict_employee_status,
63
  inputs=[
64
  gr.Number(label="Satisfaction Level (0.0 - 1.0)"),
@@ -69,16 +78,18 @@ def gradio_interface():
69
  gr.Radio([0, 1], label="Work Accident (0 = No, 1 = Yes)"),
70
  gr.Radio([0, 1], label="Promotion in Last 5 Years (0 = No, 1 = Yes)"),
71
  gr.Radio([0, 1, 2], label="Salary (0 = Low, 1 = Medium, 2 = High)"),
72
- gr.Dropdown(['RandD', 'accounting', 'hr', 'management', 'marketing',
73
- 'product_mng', 'sales', 'support', 'technical'], label="Department"),
 
 
 
74
  gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Prediction Threshold")
75
  ],
76
  outputs="text",
77
- title="Employee Retention Prediction System (Ensemble from Hugging Face Hub)",
78
- description="Predict whether an employee is likely to stay or quit based on their profile. Adjust the threshold for accurate predictions.",
79
  theme="dark"
80
- ).launch()
 
81
 
82
  gradio_interface()
83
-
84
-
 
5
  import os
6
  import warnings
7
  from huggingface_hub import hf_hub_download
 
 
8
 
9
  # Suppress warnings
10
  warnings.filterwarnings("ignore")
 
21
 
22
  model = load_model()
23
 
24
+ # Define prediction function
25
+ def predict_employee_status(
26
+ satisfaction_level, last_evaluation, number_project,
27
+ average_monthly_hours, time_spend_company,
28
+ work_accident, promotion_last_5years, salary, department, threshold=0.5
29
+ ):
30
+ departments = [
31
+ 'sales', 'accounting', 'hr', 'technical', 'support',
32
+ 'management', 'IT', 'product_mng', 'marketing', 'RandD'
33
+ ]
34
+
35
+ # One-hot encode department (include department_IT explicitly)
36
  department_features = {f"department_{dept}": 0 for dept in departments}
37
  if department in departments:
38
  department_features[f"department_{department}"] = 1
39
 
40
+ # Interaction features
41
  satisfaction_evaluation = satisfaction_level * last_evaluation
42
  work_balance = average_monthly_hours / number_project
43
 
44
+ # Construct input DataFrame
45
  input_data = {
46
  "satisfaction_level": [satisfaction_level],
47
  "last_evaluation": [last_evaluation],
 
57
  }
58
 
59
  input_df = pd.DataFrame(input_data)
 
 
 
60
 
61
+ try:
62
+ prob = model.predict_proba(input_df)[0][1]
63
+ result = "βœ… Employee is likely to quit." if prob >= threshold else "βœ… Employee is likely to stay."
64
+ return f"{result} (Probability: {prob:.2%})"
65
+ except Exception as e:
66
+ return f"❌ Prediction error: {str(e)}"
67
+
68
+ # Gradio Interface
69
  def gradio_interface():
70
+ interface = gr.Interface(
71
  fn=predict_employee_status,
72
  inputs=[
73
  gr.Number(label="Satisfaction Level (0.0 - 1.0)"),
 
78
  gr.Radio([0, 1], label="Work Accident (0 = No, 1 = Yes)"),
79
  gr.Radio([0, 1], label="Promotion in Last 5 Years (0 = No, 1 = Yes)"),
80
  gr.Radio([0, 1, 2], label="Salary (0 = Low, 1 = Medium, 2 = High)"),
81
+ gr.Dropdown(
82
+ ['sales', 'accounting', 'hr', 'technical', 'support',
83
+ 'management', 'IT', 'product_mng', 'marketing', 'RandD'],
84
+ label="Department"
85
+ ),
86
  gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Prediction Threshold")
87
  ],
88
  outputs="text",
89
+ title="Employee Retention Prediction System (Voting Ensemble)",
90
+ description="Predict whether an employee is likely to stay or quit based on their profile. Supports threshold adjustment.",
91
  theme="dark"
92
+ )
93
+ interface.launch()
94
 
95
  gradio_interface()