# Source: PoCInnovation / back / client.py
# Last commit by pierrelissope: "final update" (743fc77)
"""Client script.
This script does the following:
- Query crypto-parameters and pre/post-processing parameters (client.zip)
- Quantize the inputs using the parameters
- Encrypt data using the crypto-parameters
- Send the encrypted data to the server (async using grequests)
- Collect the data and decrypt it
- De-quantize the decrypted results
"""
import io
import os
import sys
from pathlib import Path
import grequests
import numpy
import requests
import torch
import torchvision
import torchvision.transforms as transforms
from concrete.ml.deployment import FHEModelClient
# Server location. Every piece can be overridden through the environment;
# a complete URL in the environment takes precedence over the IP/PORT pair.
PORT = os.getenv("PORT", "5000")
IP = os.getenv("IP", "localhost")
URL = os.getenv("URL", "http://{}:{}".format(IP, PORT))

# Sample count read from the environment (defaults to one).
NUM_SAMPLES = int(os.getenv("NUM_SAMPLES", "1"))

# HTTP status code the client treats as success.
STATUS_OK = 200
def _require_ok(response, action):
    """Raise RuntimeError unless ``response`` carries the expected HTTP status.

    Args:
        response: Object exposing a ``status_code`` attribute.
        action: Short description of the request, used in the error message.

    Raises:
        RuntimeError: If ``response.status_code`` differs from ``STATUS_OK``.
    """
    # An explicit raise (rather than ``assert``) keeps the check alive under
    # ``python -O``, where assert statements are stripped.
    if response.status_code != STATUS_OK:
        raise RuntimeError(
            f"{action} failed: expected HTTP {STATUS_OK}, got {response.status_code}"
        )


def main():
    """Run one encrypted inference round-trip against the server at ``URL``.

    Steps, in order:
      1. Download ``client.zip`` from ``{URL}/get_client`` into the current directory.
      2. Build an ``FHEModelClient`` and upload its serialized evaluation keys
         to ``{URL}/add_key``; the server answers with a ``uid``.
      3. Quantize, encrypt and serialize the first input sample and post it to
         ``{URL}/compute`` together with the ``uid``.
      4. Deserialize, decrypt and de-quantize the answer and print it.

    Raises:
        TypeError: If the input data placeholder has not been filled in.
        RuntimeError: If any request is answered with a status other than
            ``STATUS_OK``.
        ValueError: If the asynchronous query yields no response at all.
    """
    # Placeholder for the data to run inference on. It has to be replaced by an
    # object that supports ``[:1]`` slicing, ``[[0], :]`` indexing and
    # ``.numpy()`` (presumably a torch tensor -- TODO confirm).
    train_sub_set = ...
    if train_sub_set is Ellipsis:
        # Fail before any network traffic instead of crashing later with an
        # opaque "'ellipsis' object is not subscriptable".
        raise TypeError(
            "train_sub_set is still the `...` placeholder; "
            "assign the input data before running the client."
        )

    # Get the necessary data for the client: client.zip
    zip_response = requests.get(f"{URL}/get_client")
    _require_ok(zip_response, "Downloading client.zip")
    Path("./client.zip").write_bytes(zip_response.content)

    # Get the data to infer
    X = train_sub_set[:1]

    # Create the client
    client = FHEModelClient(path_dir="./", key_dir="./keys")

    # The client first needs to create the private and evaluation keys.
    # NOTE(review): no explicit key-generation call is made here; presumably
    # get_serialized_evaluation_keys() creates the keys on demand -- confirm
    # against the installed concrete-ml version.
    serialized_evaluation_keys = client.get_serialized_evaluation_keys()
    assert isinstance(serialized_evaluation_keys, bytes)

    # Evaluation keys can be quite large but only have to be shared once with
    # the server. Report the payload size (in MB); len() gives the number of
    # bytes actually sent, without the Python object overhead.
    print(f"Evaluation keys size: {len(serialized_evaluation_keys) / 1024 / 1024:.2f} MB")

    # Upload the evaluation keys as a file and keep the uid the server assigns.
    response = requests.post(
        f"{URL}/add_key",
        files={"key": io.BytesIO(initial_bytes=serialized_evaluation_keys)},
    )
    _require_ok(response, "Uploading evaluation keys")
    uid = response.json()["uid"]

    inferences = []

    # Launch the queries
    clear_input = X[[0], :].numpy()
    print("Input shape:", clear_input.shape)
    assert isinstance(clear_input, numpy.ndarray)

    print("Quantize/Encrypt")
    encrypted_input = client.quantize_encrypt_serialize(clear_input)  # Encrypt the data
    assert isinstance(encrypted_input, bytes)
    print(f"Encrypted input size: {len(encrypted_input) / 1024 / 1024:.2f} MB")

    print("Posting query")
    inferences.append(
        grequests.post(
            f"{URL}/compute",
            files={
                "model_input": io.BytesIO(encrypted_input),
            },
            data={
                "uid": uid,
            },
        )
    )

    # Drop the local references to the large buffers once they are queued.
    del encrypted_input
    del serialized_evaluation_keys
    print("Posted!")

    # Unpack the results
    decrypted_predictions = []
    for result in grequests.map(inferences):
        if result is None:
            raise ValueError(
                "Result is None, probably because the server crashed due to lack of available memory."
            )
        _require_ok(result, "Encrypted inference")
        print("OK!")

        encrypted_result = result.content
        decrypted_prediction = client.deserialize_decrypt_dequantize(encrypted_result)[0]
        decrypted_predictions.append(decrypted_prediction)

    print(decrypted_predictions)


if __name__ == "__main__":
    main()