import logging
import os

import gradio as gr
import joblib
import numpy as np
from concrete.ml.deployment import FHEModelClient, FHEModelServer

# Configure logging
logging.basicConfig(level=logging.INFO)
# Module-level state shared across the Gradio callbacks
key_already_generated_condition = False
encrypted_data = None
encrypted_prediction = None
# Paths to required files
SCALER_PATH = os.path.join("models", "scaler_random_forest.pkl")
FHE_FILES_PATH = os.path.join("models", "fhe_files")
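# Expected layout: models/scaler_random_forest.pkl for pre-processing, and
# models/fhe_files/ holding the client.zip and server.zip produced when the
# model was compiled with Concrete ML.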
# Load the scaler
try:
    scaler = joblib.load(SCALER_PATH)
    logging.info("Scaler loaded successfully.")
except FileNotFoundError:
    logging.error(f"Error: The scaler file is missing at {SCALER_PATH}.")
    raise
# Initialize the FHE client (encryption/decryption side) and the FHE server
# (encrypted inference side) from the deployment files
try:
    client = FHEModelClient(path_dir=FHE_FILES_PATH, key_dir=FHE_FILES_PATH)
    server = FHEModelServer(path_dir=FHE_FILES_PATH)
    server.load()
    logging.info("FHE client and server initialized successfully.")
except FileNotFoundError:
    logging.error(f"Error: The FHE files (client.zip, server.zip) are missing in {FHE_FILES_PATH}.")
    raise
# Retrieve the serialized evaluation keys
evaluation_keys = client.get_serialized_evaluation_keys()
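# The evaluation keys are shared with the server once: they allow it to run
# the model on encrypted inputs without being able to decrypt them.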


def predict():
    """
    Run the compiled FHE model on the encrypted data.

    Returns:
        tuple: The encrypted prediction as a hex string (or None) and a
            Gradio update carrying a status message.
    """
    global encrypted_data, encrypted_prediction
    if encrypted_data is None:
        return None, gr.update(value="No encrypted data to predict. Please provide encrypted data ❌")
    try:
        # Execute the model on the encrypted data, without ever decrypting it
        encrypted_prediction = server.run(
            encrypted_data, serialized_evaluation_keys=evaluation_keys
        )
        logging.info(f"Encrypted Prediction: {encrypted_prediction}")
        return encrypted_prediction.hex(), gr.update(value="FHE evaluation is done. ✅")
    except Exception as e:
        logging.error(f"Error during prediction: {e}")
        return None, gr.update(value="Error during prediction. ❌")


def decrypt_prediction():
    """
    Decrypt and interpret the prediction result.

    Returns:
        tuple: The interpreted label, a Gradio update carrying a status
            message, and an HTML probability bar.
    """
    global encrypted_prediction
    if encrypted_prediction is None:
        message = "No prediction to decrypt. Please make a prediction first. ❌"
        return message, gr.update(value=message), ""
    try:
        # Decrypt and dequantize the prediction result
        decrypted_prediction = client.deserialize_decrypt_dequantize(encrypted_prediction)
        logging.info(f"Decrypted Prediction: {decrypted_prediction}")
        # Ensure the prediction is a flat array of class probabilities
        if isinstance(decrypted_prediction, np.ndarray) and decrypted_prediction.ndim > 1:
            decrypted_prediction = decrypted_prediction.flatten()
        # Interpret the prediction: index of the most likely class
        binary_prediction = int(np.argmax(decrypted_prediction))
        # Generate the HTML for the percentage bar
        bar_html = f"""
        <div style="width: 100%; background-color: lightgray; border-radius: 5px; overflow: hidden; display: flex;">
            <div style="width: {decrypted_prediction[0] * 100}%; background-color: green; color: white; text-align: center; padding: 5px 0;">
                {decrypted_prediction[0] * 100:.1f}% Non-Fraud
            </div>
            <div style="width: {decrypted_prediction[1] * 100}%; background-color: red; color: white; text-align: center; padding: 5px 0;">
                {decrypted_prediction[1] * 100:.1f}% Fraud
            </div>
        </div>
        """
        label = "⚠️ Fraudulent ⚠️" if binary_prediction == 1 else "😊 Non-fraudulent 😊"
        return label, gr.update(value="Decryption successful ✅"), bar_html
    except Exception as e:
        logging.error(f"Error during decryption: {e}")
        message = "Error during decryption ❌"
        return message, gr.update(value=message), message


def key_already_generated():
    """
    Check whether the evaluation keys have already been generated.

    Returns:
        bool: True if the evaluation keys have been generated, False otherwise.
    """
    global key_already_generated_condition
    if evaluation_keys:
        key_already_generated_condition = True
        return True
    return False


def pre_process_encrypt_send_purchase(*inputs):
    """
    Pre-process, encrypt, and send the purchase data for prediction.

    Args:
        *inputs: The transaction features, one value per feature.

    Returns:
        tuple: The encrypted input as a hex string (or None) and a Gradio
            update carrying a status message.
    """
    global key_already_generated_condition, encrypted_data
    if not key_already_generated_condition:
        return None, gr.update(value="Generate your key first. ❌")
    try:
        logging.info(f"Input Data: {inputs}")
        # Scale the input data
        scaled_data = scaler.transform([list(inputs)])
        logging.info(f"Scaled Data: {scaled_data}")
        # Quantize, encrypt, and serialize the scaled data
        encrypted_data = client.quantize_encrypt_serialize(scaled_data)
        logging.info("Data encrypted successfully.")
        return encrypted_data.hex(), gr.update(value="Inputs are encrypted and sent to server. ✅")
    except Exception as e:
        logging.error(f"Error during pre-processing: {e}")
        return None, gr.update(value="Error during pre-processing. ❌")