import logging
import os

import gradio as gr
import joblib
import numpy as np
from concrete.ml.deployment import FHEModelClient, FHEModelServer

logging.basicConfig(level=logging.INFO)
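
# Global state shared across the Gradio callbacks below.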
key_already_generated_condition = False
encrypted_data = None
encrypted_prediction = None

SCALER_PATH = os.path.join("models", "scaler_random_forest.pkl")
FHE_FILES_PATH = os.path.join("models", "fhe_files")
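
# Load the scaler fitted on the training data; inference inputs must be
# scaled the same way before quantization and encryption.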
try:
    scaler = joblib.load(SCALER_PATH)
    logging.info("Scaler loaded successfully.")
except FileNotFoundError:
    logging.error(f"Error: The file scaler_random_forest.pkl is missing at {SCALER_PATH}.")
    raise
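
# The FHE client performs key generation, quantization, encryption, and
# decryption; the server runs the compiled model on encrypted inputs.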
try:
    client = FHEModelClient(path_dir=FHE_FILES_PATH, key_dir=FHE_FILES_PATH)
    server = FHEModelServer(path_dir=FHE_FILES_PATH)
    server.load()
    logging.info("FHE client and server initialized successfully.")
except FileNotFoundError:
    logging.error(f"Error: The FHE files (client.zip, server.zip) are missing in {FHE_FILES_PATH}.")
    raise
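
# Serialized evaluation keys allow the server to compute on ciphertexts
# without being able to decrypt them.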
evaluation_keys = client.get_serialized_evaluation_keys()


def predict():
    """
    Run the compiled FHE model on the encrypted data (the server runs locally
    in this demo).

    Returns:
        str: The encrypted prediction as a hex string, or None on failure.
        dict: A Gradio update describing the status of the prediction.
    """
    global encrypted_data, encrypted_prediction
    if encrypted_data is None:
        return None, gr.update(value="No encrypted data to predict. Please provide encrypted data ❌")
    try:
        # The server only ever sees ciphertexts: it computes on the encrypted
        # input and returns an encrypted prediction.
        encrypted_prediction = server.run(
            encrypted_data, serialized_evaluation_keys=evaluation_keys
        )
        logging.info(f"Encrypted Prediction: {encrypted_prediction}")
        return encrypted_prediction.hex(), gr.update(value="FHE evaluation is done. ✅")
    except Exception as e:
        logging.error(f"Error during prediction: {e}")
        return None, gr.update(value="Error during prediction ❌")


def decrypt_prediction():
    """
    Decrypt and interpret the prediction result.

    Returns:
        str: The interpreted prediction result.
        dict: A Gradio update describing the status of the decryption.
        str: An HTML probability bar for the two classes.
    """
    global encrypted_prediction
    if encrypted_prediction is None:
        message = "No prediction to decrypt. Please make a prediction first. ❌"
        return message, gr.update(value=message), ""
    try:
        # Only the client holds the private key, so decryption and
        # dequantization happen on the client side.
        decrypted_prediction = client.deserialize_decrypt_dequantize(encrypted_prediction)
        logging.info(f"Decrypted Prediction: {decrypted_prediction}")

        # Flatten a (1, 2) probability array to (2,) before indexing.
        if isinstance(decrypted_prediction, np.ndarray) and decrypted_prediction.ndim > 1:
            decrypted_prediction = decrypted_prediction.flatten()

        # Class 1 is "fraud"; pick the class with the highest probability.
        binary_prediction = int(np.argmax(decrypted_prediction))

        bar_html = f"""
        <div style="width: 100%; background-color: lightgray; border-radius: 5px; overflow: hidden; display: flex;">
            <div style="width: {decrypted_prediction[0] * 100}%; background-color: green; color: white; text-align: center; padding: 5px 0;">
                {decrypted_prediction[0] * 100:.1f}% Non-Fraud
            </div>
            <div style="width: {decrypted_prediction[1] * 100}%; background-color: red; color: white; text-align: center; padding: 5px 0;">
                {decrypted_prediction[1] * 100:.1f}% Fraud
            </div>
        </div>
        """
        label = "⚠️ Fraudulent ⚠️" if binary_prediction == 1 else "😊 Non-fraudulent 😊"
        return label, gr.update(value="Decryption successful ✅"), bar_html
    except Exception as e:
        logging.error(f"Error during decryption: {e}")
        return "Error during decryption ❌", gr.update(value="Error during decryption ❌"), ""


def key_already_generated():
    """
    Check whether the evaluation keys have already been generated.

    Returns:
        bool: True if the evaluation keys have already been generated, False otherwise.
    """
    global key_already_generated_condition
    if evaluation_keys:
        key_already_generated_condition = True
        return True
    return False


def pre_process_encrypt_send_purchase(*inputs):
    """
    Pre-process, encrypt, and send the purchase data for prediction.

    Args:
        *inputs: Variable number of feature values to encrypt.

    Returns:
        str: The encrypted input as a hex string, or None on failure.
        dict: A Gradio update describing the status of the encryption.
    """
    global key_already_generated_condition, encrypted_data
    if not key_already_generated_condition:
        return None, gr.update(value="Generate your key first. ❌")
    try:
        logging.info(f"Input Data: {inputs}")

        # Apply the same scaling as during training.
        scaled_data = scaler.transform([list(inputs)])
        logging.info(f"Scaled Data: {scaled_data}")

        # Quantize, encrypt, and serialize the inputs on the client side.
        encrypted_data = client.quantize_encrypt_serialize(scaled_data)
        logging.info("Data encrypted successfully.")
        return encrypted_data.hex(), gr.update(value="Inputs are encrypted and sent to server. ✅")
    except Exception as e:
        logging.error(f"Error during pre-processing: {e}")
        return None, gr.update(value="Error during pre-processing ❌")
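

# Illustrative wiring of the callbacks above into a Gradio UI. This is a
# minimal sketch, not the app's actual interface: it assumes a single numeric
# feature, and the component and button names below are made up for the
# example.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        amount = gr.Number(label="Transaction amount")
        key_status = gr.Checkbox(label="Key generated", interactive=False)
        status = gr.Textbox(label="Status", interactive=False)
        encrypted_box = gr.Textbox(label="Encrypted input (hex)", interactive=False)
        prediction_box = gr.Textbox(label="Encrypted prediction (hex)", interactive=False)
        result_box = gr.Textbox(label="Decrypted result", interactive=False)
        probability_bar = gr.HTML()

        gr.Button("Generate key").click(key_already_generated, outputs=key_status)
        gr.Button("Encrypt and send").click(
            pre_process_encrypt_send_purchase, inputs=amount, outputs=[encrypted_box, status]
        )
        gr.Button("Run FHE prediction").click(predict, outputs=[prediction_box, status])
        gr.Button("Decrypt prediction").click(
            decrypt_prediction, outputs=[result_box, status, probability_bar]
        )
    demo.launch()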