Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -8,6 +8,7 @@ import warnings
|
|
8 |
import shap
|
9 |
import matplotlib.pyplot as plt
|
10 |
from sklearn.metrics import roc_curve, precision_recall_curve
|
|
|
11 |
|
12 |
# Suppress XGBoost warnings
|
13 |
warnings.filterwarnings("ignore", category=UserWarning, message=".*WARNING.*")
|
@@ -26,17 +27,7 @@ def load_model():
|
|
26 |
|
27 |
model = load_model()
|
28 |
|
29 |
-
#
|
30 |
-
def optimize_threshold(X_test, y_test):
|
31 |
-
dmatrix = xgb.DMatrix(X_test)
|
32 |
-
y_prob = model.predict(dmatrix)
|
33 |
-
|
34 |
-
fpr, tpr, thresholds = roc_curve(y_test, y_prob)
|
35 |
-
optimal_idx = np.argmax(tpr - fpr)
|
36 |
-
optimal_threshold = thresholds[optimal_idx]
|
37 |
-
return optimal_threshold
|
38 |
-
|
39 |
-
# Prediction function with dynamic threshold
|
40 |
def predict_employee_status(satisfaction_level, last_evaluation, number_project,
|
41 |
average_monthly_hours, time_spent_company,
|
42 |
work_accident, promotion_last_5years, salary, department, threshold=0.5):
|
@@ -50,7 +41,11 @@ def predict_employee_status(satisfaction_level, last_evaluation, number_project,
|
|
50 |
if department in departments:
|
51 |
department_features[f"department_{department}"] = 1
|
52 |
|
53 |
-
#
|
|
|
|
|
|
|
|
|
54 |
input_data = {
|
55 |
"satisfaction_level": [satisfaction_level],
|
56 |
"last_evaluation": [last_evaluation],
|
@@ -60,6 +55,8 @@ def predict_employee_status(satisfaction_level, last_evaluation, number_project,
|
|
60 |
"Work_accident": [work_accident],
|
61 |
"promotion_last_5years": [promotion_last_5years],
|
62 |
"salary": [salary],
|
|
|
|
|
63 |
**department_features
|
64 |
}
|
65 |
|
@@ -127,3 +124,4 @@ def gradio_interface():
|
|
127 |
interface.launch()
|
128 |
|
129 |
gradio_interface()
|
|
|
|
8 |
import shap
|
9 |
import matplotlib.pyplot as plt
|
10 |
from sklearn.metrics import roc_curve, precision_recall_curve
|
11 |
+
from imblearn.over_sampling import SMOTE
|
12 |
|
13 |
# Suppress XGBoost warnings
|
14 |
warnings.filterwarnings("ignore", category=UserWarning, message=".*WARNING.*")
|
|
|
27 |
|
28 |
model = load_model()
|
29 |
|
30 |
+
# Prediction function with dynamic threshold and balanced data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
31 |
def predict_employee_status(satisfaction_level, last_evaluation, number_project,
|
32 |
average_monthly_hours, time_spent_company,
|
33 |
work_accident, promotion_last_5years, salary, department, threshold=0.5):
|
|
|
41 |
if department in departments:
|
42 |
department_features[f"department_{department}"] = 1
|
43 |
|
44 |
+
# Automatically Generate Interaction Features
|
45 |
+
satisfaction_evaluation = satisfaction_level * last_evaluation
|
46 |
+
work_balance = average_monthly_hours / number_project
|
47 |
+
|
48 |
+
# Prepare the input with all expected features as a DataFrame with column names
|
49 |
input_data = {
|
50 |
"satisfaction_level": [satisfaction_level],
|
51 |
"last_evaluation": [last_evaluation],
|
|
|
55 |
"Work_accident": [work_accident],
|
56 |
"promotion_last_5years": [promotion_last_5years],
|
57 |
"salary": [salary],
|
58 |
+
"satisfaction_evaluation": [satisfaction_evaluation],
|
59 |
+
"work_balance": [work_balance],
|
60 |
**department_features
|
61 |
}
|
62 |
|
|
|
124 |
interface.launch()
|
125 |
|
126 |
gradio_interface()
|
127 |
+
|