|
import os |
|
import joblib |
|
import numpy as np |
|
from concrete.ml.deployment import FHEModelClient, FHEModelServer |
|
import logging |
|
|
|
|
|
logging.basicConfig(level=logging.INFO)

# Artifact locations, relative to the current working directory.
SCALER_PATH = os.path.join("models", "scaler.pkl")
FHE_FILES_PATH = os.path.join("models", "fhe_files")

# Load the scaler object used by predict(); a missing file is logged and the
# FileNotFoundError is re-raised so the import fails loudly.
try:
    scaler = joblib.load(SCALER_PATH)
except FileNotFoundError:
    logging.error(f"Error: The file scaler.pkl is missing at {SCALER_PATH}.")
    raise
else:
    logging.info("Scaler loaded successfully.")

# Build the FHE client and server from the same directory (the client also
# keeps its keys there) and load the server. As above, a missing file is
# logged and re-raised.
try:
    client = FHEModelClient(path_dir=FHE_FILES_PATH, key_dir=FHE_FILES_PATH)
    server = FHEModelServer(path_dir=FHE_FILES_PATH)
    server.load()
except FileNotFoundError:
    logging.error(f"Error: The FHE files (client.zip, server.zip) are missing in {FHE_FILES_PATH}.")
    raise
else:
    logging.info("FHE Client and Server initialized successfully.")

# Serialized evaluation keys, fetched once at import time and handed to
# server.run() on every prediction.
evaluation_keys = client.get_serialized_evaluation_keys()
|
|
|
def predict(input_data):
    """
    Perform a local prediction using the compiled FHE model.

    The input values are scaled with the module-level scaler, then
    quantized/encrypted/serialized by the client, evaluated by the server
    using the module-level evaluation keys, and finally
    deserialized/decrypted/dequantized by the client. The predicted class is
    the arg-max index of the decrypted output.

    Args:
        input_data (dict): User input data as a dictionary. Only the values
            are used, in the dict's iteration order.
            NOTE(review): this order presumably has to match the feature
            order the scaler expects -- confirm against the caller.

    Returns:
        str: "Fraudulent" when the arg-max index is 1, "Non-fraudulent" for
        any other index, or "Error during prediction" if any step raised.
    """
    try:
        # NOTE(review): the raw and scaled inputs are logged in clear text;
        # confirm this is acceptable for the deployment.
        logging.info(f"Input Data: {input_data}")

        # The scaler is given a single-row, 2-D input built from the values.
        scaled_data = scaler.transform([list(input_data.values())])
        logging.info(f"Scaled Data: {scaled_data}")

        encrypted_data = client.quantize_encrypt_serialize(scaled_data)
        logging.info("Data encrypted successfully.")

        encrypted_prediction = server.run(
            encrypted_data, serialized_evaluation_keys=evaluation_keys
        )
        # The server output is still encrypted; dumping it into the log adds
        # no diagnostic value and can be very large, so only report progress.
        logging.info("Encrypted prediction computed successfully.")

        decrypted_prediction = client.deserialize_decrypt_dequantize(encrypted_prediction)
        logging.info(f"Decrypted Prediction: {decrypted_prediction}")

        # np.argmax without an axis works on the flattened output.
        binary_prediction = int(np.argmax(decrypted_prediction))
        return "Fraudulent" if binary_prediction == 1 else "Non-fraudulent"
    except Exception as e:
        # Boundary handler: callers receive a sentinel string rather than an
        # exception. logging.exception records the traceback so the failure
        # can still be diagnosed from the log.
        logging.exception(f"Error during prediction: {e}")
        return "Error during prediction"
|
|