"""Gradio callbacks for an FHE (Concrete ML) fraud-detection demo.

At import time this module loads a fitted scaler and the Concrete ML
client/server artifacts. It then exposes callbacks that scale and encrypt
purchase features, evaluate the FHE model on the ciphertext, and decrypt
and interpret the result.
"""

import logging
import os

import gradio as gr
import joblib
import numpy as np
from concrete.ml.deployment import FHEModelClient, FHEModelServer

# Configure logging
logging.basicConfig(level=logging.INFO)

# Module-level state shared between the Gradio callbacks below.
key_already_generated_condition = False  # set to True by key_already_generated()
encrypted_data = None  # serialized ciphertext produced by pre_process_encrypt_send_purchase()
encrypted_prediction = None  # serialized encrypted output produced by predict()

# Paths to required files
SCALER_PATH = os.path.join("models", "scaler_random_forest.pkl")
FHE_FILES_PATH = os.path.join("models", "fhe_files")

# Load the scaler
try:
    scaler = joblib.load(SCALER_PATH)
    logging.info("Scaler loaded successfully.")
except FileNotFoundError:
    # Report the actual configured path rather than a hard-coded file name.
    logging.error(f"Error: The scaler file is missing at {SCALER_PATH}.")
    raise

# Initialize the FHE client and server
try:
    client = FHEModelClient(path_dir=FHE_FILES_PATH, key_dir=FHE_FILES_PATH)
    server = FHEModelServer(path_dir=FHE_FILES_PATH)
    server.load()
    logging.info("FHE Client and Server initialized successfully.")
except FileNotFoundError:
    logging.error(f"Error: The FHE files (client.zip, server.zip) are missing in {FHE_FILES_PATH}.")
    raise

# Load evaluation keys
evaluation_keys = client.get_serialized_evaluation_keys()


def predict():
    """Run the compiled FHE model on the currently stored encrypted inputs.

    Returns:
        tuple: ``(hex_prediction, status)``. ``hex_prediction`` is the hex
        encoding of the encrypted prediction, or ``None`` when there is no
        input or evaluation failed. ``status`` is a ``gr.update`` carrying a
        user-facing message.
    """
    global encrypted_prediction
    if encrypted_data is None:
        return None, gr.update(value="No encrypted data to predict. Please provide encrypted data ❌")
    try:
        # Evaluate the model homomorphically on the ciphertext.
        encrypted_prediction = server.run(
            encrypted_data, serialized_evaluation_keys=evaluation_keys
        )
        # Log only the size: the full ciphertext is large and meaningless in logs.
        logging.info(f"Encrypted prediction computed ({len(encrypted_prediction)} bytes).")
        return encrypted_prediction.hex(), gr.update(value="FHE evaluation is done. ✅")
    except Exception as e:
        logging.exception(f"Error during prediction: {e}")
        # Input was present; report the evaluation failure, not "no data".
        return None, gr.update(value="Error during FHE evaluation ❌")


def decrypt_prediction():
    """Decrypt the latest encrypted prediction and interpret it.

    Returns:
        tuple: Three values, one per Gradio output: the label text
        (fraudulent / non-fraudulent), a status message (``gr.update``), and
        the HTML for the percentage bar. On any failure, all three elements
        are the same error message string.
    """
    if encrypted_prediction is None:
        msg = "No prediction to decrypt. Please make a prediction first. ❌"
        # Return three values, like the other branches, so every output
        # component receives a value.
        return msg, msg, msg
    try:
        # Decrypt the prediction result
        decrypted_prediction = client.deserialize_decrypt_dequantize(encrypted_prediction)
        logging.info(f"Decrypted Prediction: {decrypted_prediction}")

        # Ensure the prediction is a flat array before interpreting it.
        if isinstance(decrypted_prediction, np.ndarray) and decrypted_prediction.ndim > 1:
            decrypted_prediction = decrypted_prediction.flatten()

        # Index of the highest score; index 1 is reported as fraudulent.
        # Assumes a single-sample prediction -- TODO confirm.
        binary_prediction = int(np.argmax(decrypted_prediction))

        # NOTE(review): the percentage-bar HTML template is empty in this copy
        # of the file (only a newline). Restore the intended markup, which
        # presumably renders the scores in decrypted_prediction.
        bar_html = "\n"

        label = "⚠️ Fraudulent ⚠️" if binary_prediction == 1 else "😊 Non-fraudulent 😊"
        return label, gr.update(value="Decryption successful ✅"), bar_html
    except Exception as e:
        logging.exception(f"Error during decryption: {e}")
        return "Error during prediction❌", "Error during prediction❌", "Error during prediction❌"


def key_already_generated():
    """Check whether the evaluation keys are available.

    Sets the module flag ``key_already_generated_condition`` to True when they
    are, which unlocks pre_process_encrypt_send_purchase().

    Returns:
        bool: True if the evaluation keys are available, False otherwise.
    """
    global key_already_generated_condition
    if evaluation_keys:
        key_already_generated_condition = True
        return True
    return False


def pre_process_encrypt_send_purchase(*inputs):
    """Scale, encrypt and store the purchase features for FHE evaluation.

    Args:
        *inputs: One row of raw feature values. These are presumably in the
            column order the scaler was fitted on -- TODO confirm.

    Returns:
        tuple: ``(hex_ciphertext, status)``. ``hex_ciphertext`` is the hex
        encoding of the serialized encrypted input, or ``None`` if keys are
        missing or processing failed. ``status`` is a ``gr.update`` with a
        user-facing message.
    """
    global encrypted_data
    if not key_already_generated_condition:
        return None, gr.update(value="Generate your key before. ❌")
    try:
        # Plaintext features are sensitive in an FHE app; keep them out of
        # INFO-level logs.
        logging.debug(f"Input Data: {inputs}")

        # Scale the input data as a single row.
        scaled_data = scaler.transform([list(inputs)])
        logging.debug(f"Scaled Data: {scaled_data}")

        # Encrypt the scaled data
        encrypted_data = client.quantize_encrypt_serialize(scaled_data)
        logging.info("Data encrypted successfully.")
        return encrypted_data.hex(), gr.update(value="Inputs are encrypted and sent to server. ✅")
    except Exception as e:
        logging.exception(f"Error during pre-processing: {e}")
        # Two values, matching the other branches and the Gradio outputs.
        return None, gr.update(value="Error during pre-processing")