binoua commited on
Commit
aec7071
·
1 Parent(s): 55714b8

chore: adding a DB system

Browse files
Files changed (2) hide show
  1. handler.py +37 -7
  2. play_with_endpoint.py +21 -7
handler.py CHANGED
@@ -20,6 +20,9 @@ class EndpointHandler:
20
  # For server
21
  self.fhemodel_server = FHEModelServer(path + "/compiled_model")
22
 
 
 
 
23
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
24
  """
25
  data args:
@@ -29,13 +32,40 @@ class EndpointHandler:
29
  A :obj:`list` | `dict`: will be serialized and returned
30
  """
31
 
32
- # Get inputs
33
- encrypted_inputs = from_json(data.pop("encrypted_inputs", data))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
- # Get keys
36
- evaluation_keys = from_json(data.pop("evaluation_keys", data))
37
 
38
- # Run CML prediction
39
- encrypted_prediction = self.fhemodel_server.run(encrypted_inputs, evaluation_keys)
40
 
41
- return to_json(encrypted_prediction)
 
20
  # For server
21
  self.fhemodel_server = FHEModelServer(path + "/compiled_model")
22
 
23
+ # Simulate a database of keys
24
+ self.key_database = {}
25
+
26
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
27
  """
28
  data args:
 
32
  A :obj:`list` | `dict`: will be serialized and returned
33
  """
34
 
35
+ # Get method
36
+ method = data.pop("method", data)
37
+
38
+ if method == "save_key":
39
+
40
+ # Get keys
41
+ evaluation_keys = from_json(data.pop("evaluation_keys", data))
42
+
43
+ uid = np.random.randint(2**32)
44
+
45
+ while uid in self.key_database.keys():
46
+ uid = np.random.randint(2**32)
47
+
48
+ self.key_database[uid] = evaluation_keys
49
+
50
+ return {"uid": uid}
51
+
52
+ elif method == "inference":
53
+
54
+ uid = data.pop("uid", data)
55
+
56
+ assert uid in self.key_database.keys(), f"{uid} not in DB, {self.key_database.keys()=}"
57
+
58
+ # Get inputs
59
+ encrypted_inputs = from_json(data.pop("encrypted_inputs", data))
60
+
61
+ # Find key in the database
62
+ evaluation_keys = self.key_database[uid]
63
+
64
+ # Run CML prediction
65
+ encrypted_prediction = self.fhemodel_server.run(encrypted_inputs, evaluation_keys)
66
 
67
+ return to_json(encrypted_prediction)
 
68
 
69
+ else:
 
70
 
71
+ return
play_with_endpoint.py CHANGED
@@ -29,6 +29,10 @@ headers = {
29
 
30
  def query(payload):
31
  response = requests.post(API_URL, headers=headers, json=payload)
 
 
 
 
32
  return response.json()
33
 
34
 
@@ -60,6 +64,16 @@ fhemodel_client = FHEModelClient(path_to_model)
60
  fhemodel_client.generate_private_and_evaluation_keys()
61
  evaluation_keys = fhemodel_client.get_serialized_evaluation_keys()
62
 
 
 
 
 
 
 
 
 
 
 
63
  # Test the handler
64
  nb_good = 0
65
  nb_samples = len(X_test)
@@ -72,17 +86,17 @@ for i in range(nb_samples):
72
  # Quantize the input and encrypt it
73
  encrypted_inputs = fhemodel_client.quantize_encrypt_serialize(X_test[i].reshape(1, -1))
74
 
75
- if verbose:
76
- print(f"Size of encrypted input: {sys.getsizeof(encrypted_inputs) / 1024 / 1024} megabytes")
77
- print(f"Size of keys: {sys.getsizeof(evaluation_keys) / 1024 / 1024} megabytes")
78
-
79
- # Prepare the payload, including the evaluation keys which are needed server side
80
  payload = {
81
  "inputs": "fake",
82
  "encrypted_inputs": to_json(encrypted_inputs),
83
- "evaluation_keys": to_json(evaluation_keys),
 
84
  }
85
 
 
 
 
86
  # Run the inference on HF servers
87
  duration -= time.time()
88
  duration_inference = -time.time()
@@ -98,7 +112,7 @@ for i in range(nb_samples):
98
 
99
  if verbose or True:
100
  print(
101
- f"for {i}-th input, {prediction=} with expected {Y_test[i]} in {duration_inference} seconds"
102
  )
103
 
104
  # Measure accuracy
 
29
 
30
  def query(payload):
31
  response = requests.post(API_URL, headers=headers, json=payload)
32
+
33
+ if "error" in response:
34
+ assert False, f"Got an error: {response=}"
35
+
36
  return response.json()
37
 
38
 
 
64
  fhemodel_client.generate_private_and_evaluation_keys()
65
  evaluation_keys = fhemodel_client.get_serialized_evaluation_keys()
66
 
67
+ # Save the key in the database
68
+ payload = {
69
+ "inputs": "fake",
70
+ "evaluation_keys": to_json(evaluation_keys),
71
+ "method": "save_key",
72
+ }
73
+
74
+ uid = query(payload)["uid"]
75
+ print(f"Storing the key in the database under {uid=}")
76
+
77
  # Test the handler
78
  nb_good = 0
79
  nb_samples = len(X_test)
 
86
  # Quantize the input and encrypt it
87
  encrypted_inputs = fhemodel_client.quantize_encrypt_serialize(X_test[i].reshape(1, -1))
88
 
89
+ # Prepare the payload
 
 
 
 
90
  payload = {
91
  "inputs": "fake",
92
  "encrypted_inputs": to_json(encrypted_inputs),
93
+ "method": "inference",
94
+ "uid": uid,
95
  }
96
 
97
+ if verbose or True:
98
+ print(f"Size of the payload: {sys.getsizeof(payload) / 1024} kilobytes")
99
+
100
  # Run the inference on HF servers
101
  duration -= time.time()
102
  duration_inference = -time.time()
 
112
 
113
  if verbose or True:
114
  print(
115
+ f"for {i}-th input, {prediction=} with expected {Y_test[i]} in {duration_inference:.3f} seconds"
116
  )
117
 
118
  # Measure accuracy