Zeyadd-Mostaffa commited on
Commit
abb49f7
·
verified ·
1 Parent(s): 6cde6ed

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -20
app.py CHANGED
@@ -26,21 +26,27 @@ def predict_employee_status(
26
  average_monthly_hours, time_spend_company,
27
  work_accident, promotion_last_5years, salary, department, threshold=0.5
28
  ):
29
- departments = [
30
- 'IT', 'RandD', 'accounting', 'hr', 'management',
31
- 'marketing', 'product_mng', 'sales', 'support', 'technical'
 
 
 
 
 
32
  ]
33
 
34
- # One-hot encode department
35
- department_features = {f"department_{dept}": 0 for dept in departments}
36
- if department in departments:
37
- department_features[f"department_{department}"] = 1
 
38
 
39
- # Interaction features
40
  satisfaction_evaluation = satisfaction_level * last_evaluation
41
  work_balance = average_monthly_hours / number_project
42
 
43
- # Input data
44
  input_data = {
45
  "satisfaction_level": [satisfaction_level],
46
  "last_evaluation": [last_evaluation],
@@ -57,15 +63,7 @@ def predict_employee_status(
57
 
58
  input_df = pd.DataFrame(input_data)
59
 
60
- # Ensure exact column order
61
- expected_columns = [
62
- 'satisfaction_level', 'last_evaluation', 'number_project', 'average_monthly_hours',
63
- 'time_spend_company', 'Work_accident', 'promotion_last_5years', 'salary',
64
- 'satisfaction_evaluation', 'work_balance',
65
- 'department_IT', 'department_RandD', 'department_accounting', 'department_hr',
66
- 'department_management', 'department_marketing', 'department_product_mng',
67
- 'department_sales', 'department_support', 'department_technical'
68
- ]
69
  for col in expected_columns:
70
  if col not in input_df.columns:
71
  input_df[col] = 0
@@ -92,8 +90,8 @@ def gradio_interface():
92
  gr.Radio([0, 1], label="Promotion in Last 5 Years (0 = No, 1 = Yes)"),
93
  gr.Radio([0, 1, 2], label="Salary (0 = Low, 1 = Medium, 2 = High)"),
94
  gr.Dropdown(
95
- ['sales', 'accounting', 'hr', 'technical', 'support',
96
- 'management', 'IT', 'product_mng', 'marketing', 'RandD'],
97
  label="Department"
98
  ),
99
  gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Prediction Threshold")
@@ -107,3 +105,4 @@ def gradio_interface():
107
 
108
  gradio_interface()
109
 
 
 
26
  average_monthly_hours, time_spend_company,
27
  work_accident, promotion_last_5years, salary, department, threshold=0.5
28
  ):
29
+ # Expected columns from training
30
+ expected_columns = [
31
+ 'satisfaction_level', 'last_evaluation', 'number_project', 'average_monthly_hours',
32
+ 'time_spend_company', 'Work_accident', 'promotion_last_5years', 'salary',
33
+ 'satisfaction_evaluation', 'work_balance',
34
+ 'department_IT', 'department_RandD', 'department_accounting', 'department_hr',
35
+ 'department_management', 'department_marketing', 'department_product_mng',
36
+ 'department_sales', 'department_support', 'department_technical'
37
  ]
38
 
39
+ # Construct department one-hot features
40
+ department_features = {col: 0 for col in expected_columns if col.startswith("department_")}
41
+ dept_key = f"department_{department}"
42
+ if dept_key in department_features:
43
+ department_features[dept_key] = 1
44
 
45
+ # Create interaction features
46
  satisfaction_evaluation = satisfaction_level * last_evaluation
47
  work_balance = average_monthly_hours / number_project
48
 
49
+ # Create input dataframe
50
  input_data = {
51
  "satisfaction_level": [satisfaction_level],
52
  "last_evaluation": [last_evaluation],
 
63
 
64
  input_df = pd.DataFrame(input_data)
65
 
66
+ # Ensure all expected columns are present and ordered
 
 
 
 
 
 
 
 
67
  for col in expected_columns:
68
  if col not in input_df.columns:
69
  input_df[col] = 0
 
90
  gr.Radio([0, 1], label="Promotion in Last 5 Years (0 = No, 1 = Yes)"),
91
  gr.Radio([0, 1, 2], label="Salary (0 = Low, 1 = Medium, 2 = High)"),
92
  gr.Dropdown(
93
+ ['IT', 'RandD', 'accounting', 'hr', 'management',
94
+ 'marketing', 'product_mng', 'sales', 'support', 'technical'],
95
  label="Department"
96
  ),
97
  gr.Slider(0.1, 0.9, value=0.5, step=0.05, label="Prediction Threshold")
 
105
 
106
  gradio_interface()
107
 
108
+