Zeyadd-Mostaffa commited on
Commit
2ab8b05
·
verified ·
1 Parent(s): cdda9e4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -12
app.py CHANGED
@@ -8,6 +8,7 @@ import warnings
8
  import shap
9
  import matplotlib.pyplot as plt
10
  from sklearn.metrics import roc_curve, precision_recall_curve
 
11
 
12
  # Suppress XGBoost warnings
13
  warnings.filterwarnings("ignore", category=UserWarning, message=".*WARNING.*")
@@ -26,17 +27,7 @@ def load_model():
26
 
27
  model = load_model()
28
 
29
- # Automatically find the best threshold using ROC
30
- def optimize_threshold(X_test, y_test):
31
- dmatrix = xgb.DMatrix(X_test)
32
- y_prob = model.predict(dmatrix)
33
-
34
- fpr, tpr, thresholds = roc_curve(y_test, y_prob)
35
- optimal_idx = np.argmax(tpr - fpr)
36
- optimal_threshold = thresholds[optimal_idx]
37
- return optimal_threshold
38
-
39
- # Prediction function with dynamic threshold
40
  def predict_employee_status(satisfaction_level, last_evaluation, number_project,
41
  average_monthly_hours, time_spent_company,
42
  work_accident, promotion_last_5years, salary, department, threshold=0.5):
@@ -50,7 +41,11 @@ def predict_employee_status(satisfaction_level, last_evaluation, number_project,
50
  if department in departments:
51
  department_features[f"department_{department}"] = 1
52
 
53
- # Prepare the input with all 17 features as a DataFrame with column names
 
 
 
 
54
  input_data = {
55
  "satisfaction_level": [satisfaction_level],
56
  "last_evaluation": [last_evaluation],
@@ -60,6 +55,8 @@ def predict_employee_status(satisfaction_level, last_evaluation, number_project,
60
  "Work_accident": [work_accident],
61
  "promotion_last_5years": [promotion_last_5years],
62
  "salary": [salary],
 
 
63
  **department_features
64
  }
65
 
@@ -127,3 +124,4 @@ def gradio_interface():
127
  interface.launch()
128
 
129
  gradio_interface()
 
 
8
  import shap
9
  import matplotlib.pyplot as plt
10
  from sklearn.metrics import roc_curve, precision_recall_curve
11
+ from imblearn.over_sampling import SMOTE
12
 
13
  # Suppress XGBoost warnings
14
  warnings.filterwarnings("ignore", category=UserWarning, message=".*WARNING.*")
 
27
 
28
  model = load_model()
29
 
30
+ # Prediction function with dynamic threshold and balanced data
 
 
 
 
 
 
 
 
 
 
31
  def predict_employee_status(satisfaction_level, last_evaluation, number_project,
32
  average_monthly_hours, time_spent_company,
33
  work_accident, promotion_last_5years, salary, department, threshold=0.5):
 
41
  if department in departments:
42
  department_features[f"department_{department}"] = 1
43
 
44
+ # Automatically Generate Interaction Features
45
+ satisfaction_evaluation = satisfaction_level * last_evaluation
46
+ work_balance = average_monthly_hours / number_project
47
+
48
+ # Prepare the input with all expected features as a DataFrame with column names
49
  input_data = {
50
  "satisfaction_level": [satisfaction_level],
51
  "last_evaluation": [last_evaluation],
 
55
  "Work_accident": [work_accident],
56
  "promotion_last_5years": [promotion_last_5years],
57
  "salary": [salary],
58
+ "satisfaction_evaluation": [satisfaction_evaluation],
59
+ "work_balance": [work_balance],
60
  **department_features
61
  }
62
 
 
124
  interface.launch()
125
 
126
  gradio_interface()
127
+