chore: adding a DB system

Files changed:
- handler.py            +37 -7
- play_with_endpoint.py +21 -7
handler.py
CHANGED
@@ -20,6 +20,9 @@ class EndpointHandler:
         # For server
         self.fhemodel_server = FHEModelServer(path + "/compiled_model")
 
+        # Simulate a database of keys
+        self.key_database = {}
+
     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
         """
         data args:
@@ -29,13 +32,40 @@ class EndpointHandler:
               A :obj:`list` | `dict`: will be serialized and returned
         """
 
-        # Get inputs
-        encrypted_inputs = from_json(data.pop("encrypted_inputs", data))
+        # Get method
+        method = data.pop("method", data)
+
+        if method == "save_key":
+
+            # Get keys
+            evaluation_keys = from_json(data.pop("evaluation_keys", data))
+
+            uid = np.random.randint(2**32)
+
+            while uid in self.key_database.keys():
+                uid = np.random.randint(2**32)
+
+            self.key_database[uid] = evaluation_keys
+
+            return {"uid": uid}
+
+        elif method == "inference":
+
+            uid = data.pop("uid", data)
+
+            assert uid in self.key_database.keys(), f"{uid} not in DB, {self.key_database.keys()=}"
+
+            # Get inputs
+            encrypted_inputs = from_json(data.pop("encrypted_inputs", data))
+
+            # Find key in the database
+            evaluation_keys = self.key_database[uid]
+
+            # Run CML prediction
+            encrypted_prediction = self.fhemodel_server.run(encrypted_inputs, evaluation_keys)
 
-        # Get keys
-        evaluation_keys = from_json(data.pop("evaluation_keys", data))
+            return to_json(encrypted_prediction)
 
-        # Run CML prediction
-        encrypted_prediction = self.fhemodel_server.run(encrypted_inputs, evaluation_keys)
+        else:
 
-        return to_json(encrypted_prediction)
+            return
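Note: the handler now dispatches on a "method" field in the payload and keeps evaluation keys in an in-memory dict keyed by a random uid. The snippet below is only a standalone sketch of that dispatch pattern, runnable without concrete-ml; KeyDatabaseHandler, the byte-string payloads, and random.randrange are illustrative stand-ins, not code from this commit (the real handler deserializes fields with from_json, calls FHEModelServer.run, and draws uids with np.random.randint).

# Standalone sketch of the "method" dispatch added to EndpointHandler.
# The FHE parts are dropped: the real handler deserializes payload fields with
# from_json, runs FHEModelServer on the ciphertext, and re-serializes with to_json.
import random


class KeyDatabaseHandler:
    def __init__(self):
        # Simulated database: uid -> serialized evaluation keys (in memory only,
        # so stored keys disappear whenever the endpoint restarts)
        self.key_database = {}

    def __call__(self, data):
        method = data.pop("method", None)

        if method == "save_key":
            # Draw a fresh uid, retrying on the (unlikely) collision;
            # the commit itself uses np.random.randint(2**32) here
            uid = random.randrange(2**32)
            while uid in self.key_database:
                uid = random.randrange(2**32)
            self.key_database[uid] = data.pop("evaluation_keys")
            return {"uid": uid}

        if method == "inference":
            uid = data.pop("uid")
            assert uid in self.key_database, f"{uid} not in DB"
            evaluation_keys = self.key_database[uid]
            # Placeholder for fhemodel_server.run(encrypted_inputs, evaluation_keys)
            return {"keys_used": evaluation_keys, "ciphertext": data.pop("encrypted_inputs")}

        return None


handler = KeyDatabaseHandler()
uid = handler({"method": "save_key", "evaluation_keys": b"serialized-keys"})["uid"]
print(handler({"method": "inference", "uid": uid, "encrypted_inputs": b"ciphertext"}))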
play_with_endpoint.py
CHANGED
@@ -29,6 +29,10 @@ headers = {
 
 def query(payload):
     response = requests.post(API_URL, headers=headers, json=payload)
+
+    if "error" in response:
+        assert False, f"Got an error: {response=}"
+
     return response.json()
 
 
@@ -60,6 +64,16 @@ fhemodel_client = FHEModelClient(path_to_model)
 fhemodel_client.generate_private_and_evaluation_keys()
 evaluation_keys = fhemodel_client.get_serialized_evaluation_keys()
 
+# Save the key in the database
+payload = {
+    "inputs": "fake",
+    "evaluation_keys": to_json(evaluation_keys),
+    "method": "save_key",
+}
+
+uid = query(payload)["uid"]
+print(f"Storing the key in the database under {uid=}")
+
 # Test the handler
 nb_good = 0
 nb_samples = len(X_test)
@@ -72,17 +86,17 @@ for i in range(nb_samples):
     # Quantize the input and encrypt it
     encrypted_inputs = fhemodel_client.quantize_encrypt_serialize(X_test[i].reshape(1, -1))
 
-
-    print(f"Size of encrypted input: {sys.getsizeof(encrypted_inputs) / 1024 / 1024} megabytes")
-    print(f"Size of keys: {sys.getsizeof(evaluation_keys) / 1024 / 1024} megabytes")
-
-    # Prepare the payload, including the evaluation keys which are needed server side
+    # Prepare the payload
     payload = {
         "inputs": "fake",
         "encrypted_inputs": to_json(encrypted_inputs),
-        "evaluation_keys": to_json(evaluation_keys),
+        "method": "inference",
+        "uid": uid,
     }
 
+    if verbose or True:
+        print(f"Size of the payload: {sys.getsizeof(payload) / 1024} kilobytes")
+
     # Run the inference on HF servers
     duration -= time.time()
     duration_inference = -time.time()
@@ -98,7 +112,7 @@ for i in range(nb_samples):
 
     if verbose or True:
         print(
-            f"for {i}-th input, {prediction=} with expected {Y_test[i]} in {duration_inference} seconds"
+            f"for {i}-th input, {prediction=} with expected {Y_test[i]} in {duration_inference:.3f} seconds"
         )
 
     # Measure accuracy
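Note: client side, the serialized evaluation keys are now uploaded once through a "save_key" request, and every inference payload carries only the ciphertext plus the returned uid, which keeps per-sample payloads small. The sketch below mocks query() to show the two-step flow end to end; the mock, the byte strings, and the loop are illustrative only, whereas in play_with_endpoint.py query() POSTs the payload to the Hugging Face endpoint.

# Sketch of the two-step client flow introduced by this commit: upload the keys
# once via "save_key", then reference them by uid in every inference request.
def query(payload, _server_db={}):
    # Stand-in for requests.post(API_URL, ...).json(), mirroring handler.py's dispatch
    if payload["method"] == "save_key":
        uid = len(_server_db)
        _server_db[uid] = payload["evaluation_keys"]
        return {"uid": uid}
    assert payload["uid"] in _server_db, "save_key must be called before inference"
    return {"prediction": f"ran on {payload['encrypted_inputs']!r}"}


# Step 1: store the keys server side, once, and remember the uid
uid = query({"inputs": "fake", "evaluation_keys": b"keys", "method": "save_key"})["uid"]

# Step 2: per-sample payloads stay small because they reference the keys by uid
for encrypted_inputs in [b"enc-0", b"enc-1"]:
    payload = {
        "inputs": "fake",
        "encrypted_inputs": encrypted_inputs,
        "method": "inference",
        "uid": uid,
    }
    print(query(payload))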