# Source: PoCInnovation / back / server_deploy.py
# pierrelissope — "final update" (743fc77)
"""Deployment server.
Routes:
- Get client.zip
- Add a key
- Compute
"""
import io
import os
import uuid
from pathlib import Path
from typing import Dict
import uvicorn
from fastapi import FastAPI, Form, HTTPException, UploadFile
from fastapi.responses import FileResponse, StreamingResponse
# Absolute import (no relative import) because this script is also run outside the package itself
from concrete.ml.deployment import FHEModelServer
if __name__ == "__main__":
    app = FastAPI(debug=False)

    FILE_FOLDER = Path(__file__).parent

    # Both locations can be overridden through environment variables; by default
    # they point next to this script.
    # NOTE(review): KEY_PATH is not used anywhere below — confirm whether it is still needed.
    KEY_PATH = Path(os.environ.get("KEY_PATH", FILE_FOLDER / Path("server_keys")))
    CLIENT_SERVER_PATH = Path(os.environ.get("PATH_TO_MODEL", FILE_FOLDER / Path("dev")))
    PORT = os.environ.get("PORT", "5000")

    PATH_TO_CLIENT = (CLIENT_SERVER_PATH / "client.zip").resolve()
    PATH_TO_SERVER = (CLIENT_SERVER_PATH / "server.zip").resolve()

    # Validate with an explicit raise rather than `assert`: assertions are
    # stripped when Python runs with -O, which would silently skip these checks.
    # Done before building the FHE server so a missing file yields a clear error.
    for required_file in (PATH_TO_CLIENT, PATH_TO_SERVER):
        if not required_file.exists():
            raise FileNotFoundError(f"Required model file not found: {required_file}")

    fhe = FHEModelServer(str(CLIENT_SERVER_PATH.resolve()))

    # In-memory store of uploaded evaluation keys, indexed by the uid returned
    # from /add_key. NOTE(review): entries are never evicted, so memory grows
    # with every /add_key call.
    KEYS: Dict[str, bytes] = {}

    @app.get("/get_client")
    def get_client():
        """Get client.

        Returns:
            FileResponse: client.zip

        Raises:
            HTTPException: if the file can't be found locally
        """
        # Re-checked per request: the file may have been removed after startup.
        if not PATH_TO_CLIENT.exists():
            raise HTTPException(status_code=500, detail="Could not find client.")
        return FileResponse(PATH_TO_CLIENT, media_type="application/zip")

    @app.post("/add_key")
    async def add_key(key: UploadFile):
        """Add public key.

        Arguments:
            key (UploadFile): public key

        Returns:
            Dict[str, str]
                - uid: a personal uid identifying the stored key
        """
        uid = str(uuid.uuid4())
        KEYS[uid] = await key.read()
        return {"uid": uid}

    @app.post("/compute")
    async def compute(model_input: UploadFile, uid: str = Form()):  # noqa: B008
        """Compute the circuit over encrypted input.

        Arguments:
            model_input (UploadFile): input of the circuit
            uid (str): uid of the public key to use

        Returns:
            StreamingResponse: the result of the circuit

        Raises:
            HTTPException: 404 if no key was registered under `uid`
        """
        # An unknown uid is a client error; without this check the KeyError
        # would surface as an opaque 500 Internal Server Error.
        key = KEYS.get(uid)
        if key is None:
            raise HTTPException(status_code=404, detail="Unknown uid.")

        # NOTE(review): fhe.run is a synchronous call inside an async handler,
        # so it presumably blocks the event loop while it runs; consider
        # offloading it only after confirming FHEModelServer is thread-safe.
        encrypted_results = fhe.run(
            serialized_encrypted_quantized_data=await model_input.read(),
            serialized_evaluation_keys=key,
        )
        return StreamingResponse(
            io.BytesIO(encrypted_results),
        )

    # Listen on all network interfaces.
    uvicorn.run(app, host="0.0.0.0", port=int(PORT))