File size: 4,530 Bytes
35199db
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a288dbb
35199db
 
5ba534e
35199db
 
 
762796b
5ba534e
aec7071
d161508
762796b
 
 
 
 
 
 
aec7071
35199db
 
 
 
 
10021fb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35199db
 
 
 
 
 
 
aec7071
0f5c4ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d161508
0f5c4ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aec7071
0f5c4ad
aec7071
35199db
 
 
bb69c02
35199db
 
bb69c02
35199db
 
 
 
10021fb
35199db
aec7071
35199db
 
5ba534e
aec7071
 
35199db
 
bb69c02
0f5c4ad
bb69c02
aec7071
35199db
 
a288dbb
35199db
 
a288dbb
35199db
c0f39aa
35199db
 
 
 
 
bb69c02
a288dbb
aec7071
a288dbb
35199db
 
c0f39aa
35199db
 
bb69c02
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import numpy as np
import time
import os, sys

from pathlib import Path

from concrete.ml.deployment import FHEModelClient

import requests


def to_json(python_object):
    """Serialize `bytes` for `json.dumps(..., default=to_json)`.

    A `bytes` value is turned into a tagged dict whose `__value__` is the list of
    its byte values (ints in 0..255), so that `from_json` can rebuild it.

    Raises:
        TypeError: if `python_object` is not a `bytes` instance.
    """
    if not isinstance(python_object, bytes):
        raise TypeError(f"{python_object!r} is not JSON serializable")
    return {"__class__": "bytes", "__value__": [*python_object]}


def from_json(python_object):
    """Inverse of `to_json`: rebuild `bytes` from its tagged-dict encoding.

    Usable directly on a decoded reply, or as `json.loads(..., object_hook=from_json)`.
    A dict tagged `{"__class__": "bytes", "__value__": [...]}` becomes `bytes`.
    Any other object is returned unchanged (previously it was replaced by `None`).
    As an `object_hook`, `None` would have wiped out every ordinary JSON object.

    Returns:
        `bytes` for a tagged dict, otherwise `python_object` itself.
    """
    if isinstance(python_object, dict) and python_object.get("__class__") == "bytes":
        return bytes(python_object["__value__"])
    return python_object


API_URL = "https://h0cvbig1fkmf57eb.eu-west-1.aws.endpoints.huggingface.cloud"

# The endpoint token is read once at import time. Check it explicitly: concatenating
# "Bearer " with a missing (None) token would otherwise fail with an opaque
# "can only concatenate str" TypeError.
if os.environ.get("HF_TOKEN") is None:
    raise RuntimeError("The HF_TOKEN environment variable must be set to query the endpoint")

headers = {
    "Authorization": "Bearer " + os.environ["HF_TOKEN"],
    "Content-Type": "application/json",
}


def query(payload, allowed_retries=2, timeout=None):
    """POST `payload` as JSON to `API_URL` and return the decoded JSON reply.

    A reply is treated as a failure, and the request re-sent, when:
    - it cannot be decoded as JSON (e.g. an HTML error page), or
    - its decoded value contains an "error" entry.

    Args:
        payload: JSON-serializable request body.
        allowed_retries: number of extra attempts allowed after a failed one.
        timeout: forwarded to `requests.post`. None (the default) waits
            indefinitely, as before.

    Returns:
        The decoded JSON body of the first successful reply.

    Raises:
        RuntimeError: if the last allowed attempt still fails. (Previously this
            was an `assert`, which is stripped under `python -O`.)
    """
    while True:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=timeout)

        try:
            body = response.json()
        except ValueError:
            # requests' JSON decode errors subclass ValueError. A non-JSON body
            # (e.g. a proxy's HTML error page) is retried like an "error" reply.
            body = {
                "error": f"non-JSON reply (HTTP {response.status_code}): {response.text[:200]!r}"
            }

        if body is None or "error" not in body:
            return body

        if allowed_retries <= 0:
            raise RuntimeError(f"Got an error: {response=} response.json()={body!r}")

        # Sometimes we have "Bad gateway" error
        print(f"Warning, error {response=} response.json()={body!r} in the query, relaunching")
        allowed_retries -= 1


# Decision-tree in FHE: load OpenML dataset 44 and keep only its held-out split,
# which is what the client-side evaluation below runs on.
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
import numpy

# Directory holding the compiled model artifacts loaded by FHEModelClient
path_to_model = Path("compiled_model")

features, classes = fetch_openml(
    data_id=44,
    as_frame=False,
    cache=True,
    return_X_y=True,
)
classes = classes.astype(numpy.int64)

# 15% test split with a fixed seed; the training halves are discarded here
_, X_test, _, Y_test = train_test_split(features, classes, test_size=0.15, random_state=42)

# Recover parameters for client side
fhemodel_client = FHEModelClient(path_to_model)

# Generate the keys
fhemodel_client.generate_private_and_evaluation_keys()
evaluation_keys = fhemodel_client.get_serialized_evaluation_keys()

# Save the evaluation key in the database, in packets of at most `packet_size` bytes
# (100 MiB). The first packet is sent with "save_key" and the server returns the `uid`
# under which the key is stored. Every following packet is sent with "append_key"
# for that `uid`.
# Packets are cut by offset using len(), instead of repeatedly re-slicing the remaining
# bytes and measuring them with sys.getsizeof. getsizeof counts the object header, so
# an exact multiple of the packet size produced a trailing empty "append_key" packet.
# Re-slicing the tail also copied it again on every packet.
packet_size = 1024 * 1024 * 100
nb_key_bytes = len(evaluation_keys)
uid = None

# max(..., 1): even an empty key yields one (empty) "save_key" packet, so a uid is obtained
for i, start in enumerate(range(0, max(nb_key_bytes, 1), packet_size)):
    evaluation_keys_piece = evaluation_keys[start : start + packet_size]
    nb_remaining_bytes = nb_key_bytes - start - len(evaluation_keys_piece)

    print(
        f"Sending {i}-th piece of the key (remaining size is {nb_remaining_bytes / 1024:.2f} kbytes)"
    )

    if start == 0:
        payload = {
            "inputs": "fake",
            "evaluation_keys": to_json(evaluation_keys_piece),
            "method": "save_key",
        }

        uid = query(payload)["uid"]
        print(f"Storing the key in the database under {uid=}")

    else:
        payload = {
            "inputs": "fake",
            "evaluation_keys": to_json(evaluation_keys_piece),
            "method": "append_key",
            "uid": uid,
        }

        query(payload)

# Test the handler: encrypt each test sample, run the inference on the endpoint,
# decrypt the returned prediction and compare it with the expected label.
nb_good = 0
nb_samples = len(X_test)
verbose = True
time_start = time.perf_counter()
# Accumulated time spent waiting on the endpoint (request + server-side inference)
duration = 0.0
is_first = True

for i, (x, y_expected) in enumerate(zip(X_test, Y_test)):

    # Quantize the input and encrypt it
    encrypted_inputs = fhemodel_client.quantize_encrypt_serialize(x.reshape(1, -1))

    # Prepare the payload
    payload = {
        "inputs": "fake",
        "encrypted_inputs": to_json(encrypted_inputs),
        "method": "inference",
        "uid": uid,
    }

    if is_first:
        # sys.getsizeof(payload) only measured the (shallow) dict object, not its data.
        # Report the size of the serialized ciphertext instead.
        print(f"Size of the encrypted input: {len(encrypted_inputs) / 1024:.2f} kilobytes")
        is_first = False

    # Run the inference on HF servers
    tic = time.perf_counter()
    response = query(payload)
    duration_inference = time.perf_counter() - tic
    duration += duration_inference

    # Fail clearly on an untagged reply instead of passing it to the decryption
    encrypted_prediction = from_json(response)
    if not isinstance(encrypted_prediction, bytes):
        raise RuntimeError(f"Unexpected reply for the {i}-th input: {response!r}")

    # Decrypt the result and dequantize
    prediction_proba = fhemodel_client.deserialize_decrypt_dequantize(encrypted_prediction)[0]
    prediction = np.argmax(prediction_proba)

    if verbose:
        print(
            f"for {i}-th input, {prediction=} with expected {y_expected} in {duration_inference:.3f} seconds"
        )

    # Measure accuracy
    nb_good += int(y_expected == prediction)

if nb_samples == 0:
    # Avoid dividing by zero below when the test split is empty
    print("No test sample to evaluate")
else:
    print(f"Accuracy on {nb_samples} samples is {nb_good / nb_samples}")
    print(f"Total time: {time.perf_counter() - time_start:.3f} seconds")
    print(f"Duration per inference: {duration / nb_samples:.3f} seconds")