"""Deployment server.

Routes:
    - Get client.zip
    - Add a key
    - Compute
"""
import io | |
import os | |
import uuid | |
from pathlib import Path | |
from typing import Dict | |
import uvicorn | |
from fastapi import FastAPI, Form, HTTPException, UploadFile | |
from fastapi.responses import FileResponse, StreamingResponse | |
# No relative import here because this file is also used outside of the package itself
from concrete.ml.deployment import FHEModelServer | |
if __name__ == "__main__":
    app = FastAPI(debug=False)

    FILE_FOLDER = Path(__file__).parent

    # Every location can be overridden through the environment; the defaults
    # sit next to this file. KEY_PATH is not read anywhere else in this block.
    KEY_PATH = Path(os.environ.get("KEY_PATH", FILE_FOLDER / Path("server_keys")))
    CLIENT_SERVER_PATH = Path(os.environ.get("PATH_TO_MODEL", FILE_FOLDER / Path("dev")))
    PORT = os.environ.get("PORT", "5000")

    fhe = FHEModelServer(str(CLIENT_SERVER_PATH.resolve()))

    # In-memory store: generated uid -> raw bytes uploaded through add_key.
    # Entries are never evicted and do not survive a restart.
    KEYS: Dict[str, bytes] = {}

    PATH_TO_CLIENT = (CLIENT_SERVER_PATH / "client.zip").resolve()
    PATH_TO_SERVER = (CLIENT_SERVER_PATH / "server.zip").resolve()

    # Explicit checks rather than `assert`, which is stripped under `python -O`
    # and would let the server start without its deployment files.
    for required_file in (PATH_TO_CLIENT, PATH_TO_SERVER):
        if not required_file.exists():
            raise FileNotFoundError(f"Required deployment file not found: {required_file}")

    # NOTE(review): the handlers below were never registered on `app`, so no
    # route was served. The route paths are assumed to mirror the handler
    # names -- confirm against the client that calls this server.

    @app.get("/get_client")
    def get_client():
        """Get client.

        Returns:
            FileResponse: client.zip

        Raises:
            HTTPException: if the file cannot be found locally
        """
        # Checked again on every request: the file may have been removed
        # since the start-up check.
        if not PATH_TO_CLIENT.exists():
            raise HTTPException(status_code=500, detail="Could not find client.")
        return FileResponse(PATH_TO_CLIENT, media_type="application/zip")

    @app.post("/add_key")
    async def add_key(key: UploadFile):
        """Store an uploaded key under a freshly generated uid.

        The stored bytes are the ones later handed to the model server as
        serialized evaluation keys by `compute`.

        Arguments:
            key (UploadFile): key file to store

        Returns:
            Dict[str, str]
                - uid: identifier to send back when calling compute
        """
        uid = str(uuid.uuid4())
        KEYS[uid] = await key.read()
        return {"uid": uid}

    @app.post("/compute")
    async def compute(model_input: UploadFile, uid: str = Form()):  # noqa: B008
        """Compute the circuit over encrypted input.

        Arguments:
            model_input (UploadFile): input of the circuit
            uid (str): uid returned by add_key for the key to use

        Returns:
            StreamingResponse: the result of the circuit

        Raises:
            HTTPException: 404 if no key was registered under uid
        """
        # A bare KEYS[uid] would turn an unknown uid into a KeyError, i.e. an
        # opaque 500; report it as a client error instead.
        key = KEYS.get(uid)
        if key is None:
            raise HTTPException(status_code=404, detail="No key registered for this uid.")

        encrypted_results = fhe.run(
            serialized_encrypted_quantized_data=await model_input.read(),
            serialized_evaluation_keys=key,
        )
        return StreamingResponse(
            io.BytesIO(encrypted_results),
        )

    # Binds on all interfaces so the server is reachable from outside the host.
    uvicorn.run(app, host="0.0.0.0", port=int(PORT))