Zeyadd-Mostaffa commited on
Commit
2750f6c
Β·
verified Β·
1 Parent(s): 7a63bb7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -10
app.py CHANGED
@@ -1,13 +1,32 @@
1
  import gradio as gr
2
- import joblib
3
  import numpy as np
 
 
 
 
 
 
4
 
5
- # Load your model
6
- model = joblib.load('best_model.json')
 
 
 
 
 
 
 
 
 
7
 
8
- def predict_retention(satisfaction_level, last_evaluation, number_project,
9
- average_monthly_hours, time_spent_company,
10
- work_accident, promotion_last_5years, salary, department):
 
 
 
 
11
  # One-hot encode the department
12
  departments = [
13
  'RandD', 'accounting', 'hr', 'management', 'marketing',
@@ -23,14 +42,19 @@ def predict_retention(satisfaction_level, last_evaluation, number_project,
23
  ] + department_encoded).reshape(1, -1)
24
 
25
  # Predict using the model
 
 
 
26
  try:
27
- prediction = model.predict(input_data)
28
- return "Employee is likely to quit." if prediction[0] == 1 else "Employee is likely to stay."
 
29
  except Exception as e:
30
- return f"Error: {str(e)}"
31
 
 
32
  interface = gr.Interface(
33
- fn=predict_retention,
34
  inputs=[
35
  gr.Number(label="Satisfaction Level (0.0 - 1.0)"),
36
  gr.Number(label="Last Evaluation (0.0 - 1.0)"),
 
1
  import gradio as gr
2
+ import xgboost as xgb
3
  import numpy as np
4
+ import joblib
5
+ import os
6
+ import warnings
7
+
8
+ # Suppress XGBoost warnings
9
+ warnings.filterwarnings("ignore", category=UserWarning, message=".*WARNING.*")
10
 
11
+ # Load your model (automatically detect XGBoost or joblib model)
12
+ def load_model():
13
+ model_path = "best_model.json" # Ensure this matches your file name
14
+ if os.path.exists(model_path):
15
+ model = xgb.Booster()
16
+ model.load_model(model_path)
17
+ print("βœ… Model loaded successfully.")
18
+ return model
19
+ else:
20
+ print("❌ Model file not found.")
21
+ return None
22
 
23
+ model = load_model()
24
+
25
+ # Prediction function
26
+ def predict_employee_status(satisfaction_level, last_evaluation, number_project,
27
+ average_monthly_hours, time_spent_company,
28
+ work_accident, promotion_last_5years, salary, department):
29
+
30
  # One-hot encode the department
31
  departments = [
32
  'RandD', 'accounting', 'hr', 'management', 'marketing',
 
42
  ] + department_encoded).reshape(1, -1)
43
 
44
  # Predict using the model
45
+ if model is None:
46
+ return "❌ No model found. Please upload the model file."
47
+
48
  try:
49
+ dmatrix = xgb.DMatrix(input_data)
50
+ prediction = model.predict(dmatrix)
51
+ return "Employee is likely to quit." if prediction[0] > 0.5 else "Employee is likely to stay."
52
  except Exception as e:
53
+ return f"❌ Error: {str(e)}"
54
 
55
+ # Gradio interface
56
  interface = gr.Interface(
57
+ fn=predict_employee_status,
58
  inputs=[
59
  gr.Number(label="Satisfaction Level (0.0 - 1.0)"),
60
  gr.Number(label="Last Evaluation (0.0 - 1.0)"),