File size: 5,445 Bytes
cf5e9c7
 
 
 
 
43cc119
cf5e9c7
 
 
43cc119
 
 
cf5e9c7
 
e788295
cf5e9c7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43cc119
cf5e9c7
 
 
43cc119
 
cf5e9c7
ad565c9
43cc119
 
 
cf5e9c7
 
 
 
 
 
e05b6fd
43cc119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf5e9c7
 
 
 
 
 
f2eadfb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43cc119
cf5e9c7
 
f2eadfb
43cc119
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e05b6fd
43cc119
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import joblib
import numpy as np
from concrete.ml.deployment import FHEModelClient, FHEModelServer
import logging
import gradio as gr

# Configure logging
logging.basicConfig(level=logging.INFO)

# Mutable state shared between the Gradio callbacks defined below.
key_already_generated_condition = False  # set to True by key_already_generated()
encrypted_data = None  # serialized encrypted input from quantize_encrypt_serialize(), or None
encrypted_prediction = None  # serialized encrypted model output from server.run(), or None

# Paths to required files
SCALER_PATH = os.path.join("models", "scaler_random_forest.pkl")
FHE_FILES_PATH = os.path.join("models", "fhe_files")

# Load the scaler applied to raw inputs before quantization/encryption.
# NOTE(review): joblib.load unpickles the file, so SCALER_PATH must only ever
# point at a trusted, locally shipped artifact — never at user-supplied data.
try:
    scaler = joblib.load(SCALER_PATH)
    logging.info("Scaler loaded successfully.")
except FileNotFoundError:
    # Report the path that was actually tried, not a hard-coded file name.
    logging.error(f"Error: The scaler file is missing at {SCALER_PATH}.")
    raise

# Initialize the FHE client and server
try:
    # The client stores its keys in the same directory as the compiled artifacts.
    client = FHEModelClient(path_dir=FHE_FILES_PATH, key_dir=FHE_FILES_PATH)
    server = FHEModelServer(path_dir=FHE_FILES_PATH)
    server.load()
    logging.info("FHE Client and Server initialized successfully.")
except FileNotFoundError:
    logging.error(f"Error: The FHE files (client.zip, server.zip) are missing in {FHE_FILES_PATH}.")
    raise

# Serialized evaluation keys passed to server.run().
# NOTE(review): the client may create the key pair here if none exists in
# key_dir yet — confirm against the Concrete ML version in use.
evaluation_keys = client.get_serialized_evaluation_keys()

def predict():
    """
    Run the compiled FHE model on the currently stored encrypted input.

    Reads the module-level ``encrypted_data`` and ``evaluation_keys``, and
    stores the serialized encrypted output in the module-level
    ``encrypted_prediction``.

    Returns:
        str | None: Hex representation of the encrypted prediction, or None
            when there is no input or the evaluation failed.
        gr.update: Status message update for the UI.
    """
    global encrypted_data, encrypted_prediction
    if encrypted_data is None:
        return None, gr.update(value="No encrypted data to predict. Please provide encrypted data ❌")
    try:
        # Evaluate the model on the serialized ciphertext.
        encrypted_prediction = server.run(
            encrypted_data, serialized_evaluation_keys=evaluation_keys
        )
        # Log only the size: the ciphertext itself can be very large.
        logging.info(f"Encrypted Prediction: {len(encrypted_prediction)} bytes")
        return encrypted_prediction.hex(), gr.update(value="FHE evaluation is done. ✅")
    except Exception:
        # Discard any result left over from a previous run so that
        # decrypt_prediction() cannot decrypt a prediction for older input.
        encrypted_prediction = None
        logging.exception("Error during FHE evaluation")
        return None, gr.update(value="Error during FHE evaluation ❌")
    
def decrypt_prediction():
    """
    Decrypt the stored encrypted prediction and render it for the UI.

    Reads the module-level ``encrypted_prediction`` produced by ``predict``.

    Returns:
        tuple: Always three values, so every path matches the callback's
            outputs (presumably three Gradio components — confirm at the
            event binding):
            - str: Human-readable verdict, or an error message.
            - gr.update | str: Status message.
            - str: HTML percentage bar of the two class scores, or an
              error message.
    """
    global encrypted_prediction
    if encrypted_prediction is None:
        msg = "No prediction to decrypt. Please make a prediction first. ❌"
        # Same arity as the success and error paths.
        return msg, msg, msg
    try:
        # Decrypt the prediction result
        decrypted_prediction = client.deserialize_decrypt_dequantize(encrypted_prediction)
        logging.info(f"Decrypted Prediction: {decrypted_prediction}")

        # Flatten e.g. a (1, n_classes) output so that index 0 / 1 address
        # the Non-Fraud / Fraud scores used below.
        scores = np.asarray(decrypted_prediction).ravel()
        binary_prediction = int(np.argmax(scores))

        # Generate the HTML for the percentage bar
        bar_html = f"""
        <div style="width: 100%; background-color: lightgray; border-radius: 5px; overflow: hidden; display: flex;">
            <div style="width: {scores[0] * 100}%; background-color: green; color: white; text-align: center; padding: 5px 0;">
                {scores[0] * 100:.1f}% Non-Fraud
            </div>
            <div style="width: {scores[1] * 100}%; background-color: red; color: white; text-align: center; padding: 5px 0;">
                {scores[1] * 100:.1f}% Fraud
            </div>
        </div>
        """
        verdict = "⚠️ Fraudulent ⚠️" if binary_prediction == 1 else "😊 Non-fraudulent 😊"
        return verdict, gr.update(value="Decryption successful ✅"), bar_html

    except Exception:
        logging.exception("Error during decryption")
        return "Error during decryption ❌", "Error during decryption ❌", "Error during decryption ❌"

def key_already_generated():
    """
    Report whether serialized evaluation keys are available.

    When the module-level ``evaluation_keys`` is truthy, the module-level flag
    ``key_already_generated_condition`` is latched to True as a side effect;
    otherwise the flag is left untouched.

    Returns:
        bool: True if evaluation keys are present, False otherwise.
    """
    global key_already_generated_condition
    if not evaluation_keys:
        return False
    key_already_generated_condition = True
    return True

def pre_process_encrypt_send_purchase(*inputs):
    """
    Scale, quantize-encrypt and store one purchase record for evaluation.

    The resulting ciphertext is stored in the module-level ``encrypted_data``,
    which ``predict`` consumes.

    Args:
        *inputs: Raw feature values, one per model feature (presumably in the
            column order the scaler was fit on — TODO confirm against the UI).

    Returns:
        tuple: ``(hex_str | None, gr.update)`` — hex representation of the
            encrypted input (None if keys are missing or processing failed)
            and a status message update.
    """
    global encrypted_data
    if not key_already_generated_condition:
        return None, gr.update(value="Generate your key before. ❌")
    try:
        logging.info(f"Input Data: {inputs}")

        # Scale the input data; wrapping in a list makes a single-row 2-D input.
        scaled_data = scaler.transform([list(inputs)])
        logging.info(f"Scaled Data: {scaled_data}")

        # Quantize, encrypt and serialize the scaled data
        encrypted_data = client.quantize_encrypt_serialize(scaled_data)
        logging.info("Data encrypted successfully.")
        return encrypted_data.hex(), gr.update(value="Inputs are encrypted and sent to server. ✅")
    except Exception:
        # Forget any previously submitted ciphertext so predict() cannot
        # silently evaluate stale input after this submission failed.
        encrypted_data = None
        logging.exception("Error during pre-processing")
        # Return the same two-value shape as the other paths.
        return None, gr.update(value="Error during pre-processing ❌")