xeroISB commited on
Commit
d05b013
·
1 Parent(s): 1850e78

Initial commit

Browse files
Files changed (3) hide show
  1. cox_model.pkl +0 -0
  2. inference.py +37 -0
  3. requirements.txt +3 -0
cox_model.pkl ADDED
Binary file (200 kB). View file
 
inference.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import joblib
2
+ import pandas as pd
3
+ from lifelines import CoxPHFitter
4
+
5
+ class HRAttritionModel:
6
+ def __init__(self, model_path):
7
+ self.model = joblib.load(model_path)
8
+ self.features = ['Age', 'DistanceFromHome', 'Education', 'NumCompaniesWorked', 'PercentSalaryHike',
9
+ 'TotalWorkingYears', 'TrainingTimesLastYear', 'WorkLifeBalance', 'YearsInCurrentRole',
10
+ 'YearsSinceLastPromotion', 'YearsWithCurrManager', 'BusinessTravel_Travel_Rarely',
11
+ 'BusinessTravel_Travel_Frequently', 'Department_Research & Development', 'Department_Sales',
12
+ 'EducationField_Life Sciences', 'EducationField_Medical', 'EducationField_Marketing',
13
+ 'EducationField_Other', 'EducationField_Technical Degree', 'Gender_Male', 'JobRole_Research Scientist',
14
+ 'JobRole_Sales Executive', 'JobRole_Laboratory Technician', 'JobRole_Manufacturing Director',
15
+ 'JobRole_Healthcare Representative', 'JobRole_Manager', 'JobRole_Sales Representative',
16
+ 'JobRole_Research Director', 'MaritalStatus_Married', 'MaritalStatus_Single', 'OverTime_Yes']
17
+
18
+ def predict_survival(self, input_data):
19
+ df = pd.DataFrame([input_data], columns=self.features)
20
+ survival_function = self.model.predict_survival_function(df)
21
+ return survival_function.T
22
+
23
+ # Load the model and make a prediction for testing
24
+ if __name__ == "__main__":
25
+ model = HRAttritionModel('cox_model.pkl')
26
+ sample_input = {'Age': 41, 'DistanceFromHome': 1, 'Education': 2, 'NumCompaniesWorked': 1, 'PercentSalaryHike': 11,
27
+ 'TotalWorkingYears': 8, 'TrainingTimesLastYear': 0, 'WorkLifeBalance': 1, 'YearsInCurrentRole': 4,
28
+ 'YearsSinceLastPromotion': 0, 'YearsWithCurrManager': 5, 'BusinessTravel_Travel_Rarely': 1,
29
+ 'BusinessTravel_Travel_Frequently': 0, 'Department_Research & Development': 0, 'Department_Sales': 1,
30
+ 'EducationField_Life Sciences': 1, 'EducationField_Medical': 0, 'EducationField_Marketing': 0,
31
+ 'EducationField_Other': 0, 'EducationField_Technical Degree': 0, 'Gender_Male': 1,
32
+ 'JobRole_Research Scientist': 0, 'JobRole_Sales Executive': 0, 'JobRole_Laboratory Technician': 0,
33
+ 'JobRole_Manufacturing Director': 0, 'JobRole_Healthcare Representative': 0, 'JobRole_Manager': 0,
34
+ 'JobRole_Sales Representative': 0, 'JobRole_Research Director': 0, 'MaritalStatus_Married': 0,
35
+ 'MaritalStatus_Single': 1, 'OverTime_Yes': 0}
36
+ prediction = model.predict_survival(sample_input)
37
+ print(prediction)
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ pandas
2
+ joblib
3
+ lifelines