Tenefix commited on
Commit
cf5e9c7
·
verified ·
1 Parent(s): 630ff31

Create predictor.py

Browse files
Files changed (1) hide show
  1. predictor.py +71 -0
predictor.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import joblib
3
+ import numpy as np
4
+ from concrete.ml.deployment import FHEModelClient, FHEModelServer
5
+ import logging
6
+
7
+ # Configure logging
8
+ logging.basicConfig(level=logging.INFO)
9
+
10
+ # Paths to required files
11
+ SCALER_PATH = os.path.join("models", "scaler.pkl")
12
+ FHE_FILES_PATH = os.path.join("models", "fhe_files")
13
+
14
+ # Load the scaler
15
+ try:
16
+ scaler = joblib.load(SCALER_PATH)
17
+ logging.info("Scaler loaded successfully.")
18
+ except FileNotFoundError:
19
+ logging.error(f"Error: The file scaler.pkl is missing at {SCALER_PATH}.")
20
+ raise
21
+
22
+ # Initialize the FHE client and server
23
+ try:
24
+ client = FHEModelClient(path_dir=FHE_FILES_PATH, key_dir=FHE_FILES_PATH)
25
+ server = FHEModelServer(path_dir=FHE_FILES_PATH)
26
+ server.load()
27
+ logging.info("FHE Client and Server initialized successfully.")
28
+ except FileNotFoundError:
29
+ logging.error(f"Error: The FHE files (client.zip, server.zip) are missing in {FHE_FILES_PATH}.")
30
+ raise
31
+
32
+ # Load evaluation keys
33
+ evaluation_keys = client.get_serialized_evaluation_keys()
34
+
35
+ def predict(input_data):
36
+ """
37
+ Perform a local prediction using the compiled FHE model.
38
+
39
+ Args:
40
+ input_data (dict): User input data as a dictionary.
41
+
42
+ Returns:
43
+ str: Prediction result ("Fraudulent" or "Non-fraudulent").
44
+ """
45
+ try:
46
+ logging.info(f"Input Data: {input_data}")
47
+
48
+ # Scale the input data
49
+ scaled_data = scaler.transform([list(input_data.values())])
50
+ logging.info(f"Scaled Data: {scaled_data}")
51
+
52
+ # Encrypt the scaled data
53
+ encrypted_data = client.quantize_encrypt_serialize(scaled_data)
54
+ logging.info("Data encrypted successfully.")
55
+
56
+ # Execute the model locally on encrypted data
57
+ encrypted_prediction = server.run(
58
+ encrypted_data, serialized_evaluation_keys=evaluation_keys
59
+ )
60
+ logging.info(f"Encrypted Prediction: {encrypted_prediction}")
61
+
62
+ # Decrypt the prediction result
63
+ decrypted_prediction = client.deserialize_decrypt_dequantize(encrypted_prediction)
64
+ logging.info(f"Decrypted Prediction: {decrypted_prediction}")
65
+
66
+ # Interpret the prediction
67
+ binary_prediction = int(np.argmax(decrypted_prediction))
68
+ return "Fraudulent" if binary_prediction == 1 else "Non-fraudulent"
69
+ except Exception as e:
70
+ logging.error(f"Error during prediction: {e}")
71
+ return "Error during prediction"