// private-synthid / wasm-demo.js
// Last commit: 0cdcb6f "improve ux" by jfrery-zama (unverified signature) · 15.5 kB
import initWasm, {
decrypt_serialized_u64_radix_flat_wasm
} from './concrete-ml-extensions-wasm/concrete_ml_extensions_wasm.js';
// Backend that receives the server key and runs the encrypted "synthid" task.
const SERVER = 'https://api.zama.ai:444';

// Key material produced by the keygen worker.
let clientKey;
let serverKey;

// Ciphertexts: the encrypted input tokens, and the encrypted result returned by the server.
let encTokens;
let encServerResult;

// Web workers for key generation and for token encryption.
let keygenWorker;
let encryptWorker;

// Server-side identifiers: session UID (from /add_key) and current task ID (from /start_task).
let sessionUid;
let taskId;

// Number of tokens in the most recently tokenized text; drives the progress estimate.
let currentTokenCount = 0;
// Memory-efficient base64 encoding for large Uint8Array
/**
 * Base64-encodes binary data by having FileReader produce a `data:` URL and
 * stripping its header, instead of building a binary string by hand.
 *
 * @param {Uint8Array} uint8 - Bytes to encode.
 * @returns {Promise<string>} The base64 payload (without the `data:...;base64,` prefix).
 *   Rejects with an Error (the reader's own error attached as `cause`) if reading
 *   fails or is aborted.
 */
function uint8ToBase64(uint8) {
  return new Promise((resolve, reject) => {
    const reader = new FileReader();
    reader.onload = () => {
      const dataUrl = reader.result;
      // A data URL looks like "data:<mime>;base64,<payload>"; keep only the payload.
      resolve(dataUrl.slice(dataUrl.indexOf(',') + 1));
    };
    // Reject with a real Error rather than the raw ProgressEvent, so callers get a
    // message and stack; the underlying DOMException is preserved as `cause`.
    reader.onerror = () => {
      reject(new Error('Failed to base64-encode data', { cause: reader.error }));
    };
    reader.onabort = () => {
      reject(new Error('Base64 encoding was aborted', { cause: reader.error }));
    };
    reader.readAsDataURL(new Blob([uint8]));
  });
}
// Thin DOM helpers: look up an element by id, and toggle its `disabled` / `hidden` flags.
// Each setter evaluates to the flag value it just assigned (like an assignment expression).
const $ = (elementId) => document.getElementById(elementId);
const enable = (elementId, ok = true) => {
  const disabled = !ok;
  $(elementId).disabled = disabled;
  return disabled;
};
const show = (elementId, visible = true) => {
  const hidden = !visible;
  $(elementId).hidden = hidden;
  return hidden;
};
// Progress bar functions

/** Reveals the progress widget and resets both the bar and its label to 0%. */
function showProgress() {
  const initial = '0%';
  $('progressContainer').style.display = 'block';
  $('progressBar').style.width = initial;
  $('progressText').textContent = initial;
}
/** Collapses the progress widget; the bar width and label are left as they are. */
function hideProgress() {
  $('progressContainer').style.display = 'none';
}
/**
 * Paints the progress bar at `percent` and shows the rounded value as its label.
 * @param {number} percent - Completion percentage; used unrounded for the bar width.
 */
function updateProgress(percent) {
  const label = `${Math.round(percent)}%`;
  $('progressBar').style.width = `${percent}%`;
  $('progressText').textContent = label;
}
// Progress animation based on token count
// Handle of the running setInterval timer; null/undefined while no animation runs.
let progressInterval;

/**
 * Estimates progress of a server job, assuming a fixed processing cost per token.
 * Pure helper, so the estimate can be checked without timers or the DOM.
 *
 * @param {number} elapsedMs - Milliseconds since the job started.
 * @param {number} tokenCount - Number of tokens being processed.
 * @param {number} [msPerToken=30000] - Assumed processing time per token.
 * @returns {{percent: number, remainingSeconds: number}} `percent` is clamped to
 *   [0, 95] (completion is reported separately); `remainingSeconds` is rounded up.
 */
function computeProgressEstimate(elapsedMs, tokenCount, msPerToken = 30000) {
  const maxPercent = 95; // Cap at 95% until actual completion
  const totalMs = tokenCount * msPerToken;
  if (!(totalMs > 0)) {
    // No usable estimate (0 tokens or bad input): avoid 0/0 = NaN in the bar.
    return { percent: maxPercent, remainingSeconds: 0 };
  }
  const percent = Math.min(Math.max((elapsedMs / totalMs) * 100, 0), maxPercent);
  const remainingSeconds = Math.ceil(Math.max(0, totalMs - elapsedMs) / 1000);
  return { percent, remainingSeconds };
}

/**
 * Starts a timer (every 100 ms) that advances the progress bar from the
 * per-token estimate and shows a countdown in #srvProgress.
 *
 * @param {number} tokenCount - Number of tokens submitted.
 * @param {number} [msPerToken=30000] - Assumed processing time per token.
 */
function startProgressAnimation(tokenCount, msPerToken = 30000) {
  // Never leave an earlier timer running; it would keep repainting the bar forever.
  if (progressInterval) {
    clearInterval(progressInterval);
  }
  // performance.now() is monotonic, unlike Date.now(), so clock changes can't skew the bar.
  const startedAt = performance.now();
  progressInterval = setInterval(() => {
    const { percent, remainingSeconds } = computeProgressEstimate(
      performance.now() - startedAt,
      tokenCount,
      msPerToken,
    );
    updateProgress(percent);
    if (remainingSeconds > 0) {
      $('srvProgress').textContent = `Processing ${tokenCount} tokens... (~${remainingSeconds}s remaining)`;
      show('srvProgress', true);
    }
  }, 100); // Update every 100ms for smooth animation
}

/** Stops the progress timer (if any) and hides the #srvProgress countdown. */
function stopProgressAnimation() {
  if (progressInterval) {
    clearInterval(progressInterval);
    progressInterval = null;
  }
  show('srvProgress', false);
}
// Hide all spinners immediately, plus `encIcon`, which is only revealed once
// encryption succeeds.
for (const indicatorId of ['keygenSpin', 'spin', 'encIcon', 'tokenizerSpin']) {
  show(indicatorId, false);
}
// Initialize WASM

/**
 * Routes messages from the encryption worker: 'ready' after init, 'success' with
 * the encrypted tokens in `data.result`, or 'error' with a message in `data.error`.
 * @param {MessageEvent} e
 */
function handleEncryptWorkerMessage(e) {
  const { type } = e.data;
  if (type === 'ready') {
    console.log('[Main] Encryption worker ready');
  } else if (type === 'success') {
    encTokens = e.data.result;
    console.log(`[Main] Encryption completed: ${encTokens.length} bytes`);
    show('encryptSpin', false);
    show('encIcon', true);
    enable('btnEncrypt', true);
    enable('btnSend');
    $('encStatus').textContent = `Your text is encrypted 🔒 (${currentTokenCount} tokens)`;
  } else if (type === 'error') {
    console.error('[Main] Encryption error:', e.data.error);
    show('encryptSpin', false);
    enable('btnEncrypt', true);
    $('encStatus').textContent = `Encryption failed: ${e.data.error}`;
    alert(`Encryption failed: ${e.data.error}`);
  }
}

/**
 * (Re)creates the encryption worker and seeds it with the client key.
 * @param {Uint8Array} key - Serialized client key from the keygen worker.
 */
function startEncryptWorker(key) {
  if (encryptWorker) {
    // Keys were regenerated: the old worker holds a stale client key, and any
    // encryption still running in it would yield ciphertext the new session can't use.
    encryptWorker.terminate();
    show('encryptSpin', false);
  }
  encryptWorker = new Worker(new URL('./encrypt-worker.js', import.meta.url), { type: 'module' });
  encryptWorker.onmessage = handleEncryptWorkerMessage;
  // Without this, a worker that fails to load or crashes leaves the spinner up forever.
  encryptWorker.onerror = (event) => {
    console.error('[Main] Encryption worker error:', event);
    show('encryptSpin', false);
    $('encStatus').textContent = `Encryption failed: ${event.message ?? 'worker error'}`;
  };
  // Initialize the worker with the client key
  encryptWorker.postMessage({ type: 'init', clientKey: key });
}

/**
 * Registers the server key with the backend for the 'synthid' task.
 * @param {Uint8Array} key - Serialized server key.
 * @returns {Promise<string>} The session UID assigned by /add_key.
 * @throws {Error} If the server answers with a non-2xx status.
 */
async function uploadServerKey(key) {
  const formData = new FormData();
  const serverKeyBlob = new Blob([key], { type: 'application/octet-stream' });
  formData.append('key', new File([serverKeyBlob], "server.key"));
  formData.append('task_name', 'synthid');
  const addKeyResponse = await fetch(`${SERVER}/add_key`, {
    method: 'POST',
    body: formData,
  });
  if (!addKeyResponse.ok) {
    const errorText = await addKeyResponse.text();
    throw new Error(`Server /add_key failed: ${addKeyResponse.status} ${errorText}`);
  }
  const { uid } = await addKeyResponse.json();
  return uid;
}

/**
 * Handles the keygen worker's reply: on success, stores the new key pair, resets
 * state tied to any previous pair, starts the encryption worker, and uploads the
 * server key to obtain a session UID. Any other message type is shown as an error.
 * @param {MessageEvent} e
 */
async function handleKeygenMessage(e) {
  if (e.data.type !== 'success') {
    console.error('[Main] Key generation error:', e.data.error);
    $('keygenStatus').textContent = `Error generating keys: ${e.data.error}`;
    show('keygenSpin', false);
    return;
  }
  const res = e.data.result;
  console.log('[Main] Key generation successful');
  console.log(`[Main] Client key size: ${res.clientKey.length} bytes`);
  console.log(`[Main] Server key size: ${res.serverKey.length} bytes`);
  clientKey = res.clientKey;
  serverKey = res.serverKey;
  // Fresh keys invalidate everything tied to the previous pair: ciphertexts made with
  // the old client key and the session registered with the old server key. Keep
  // Encrypt disabled until the new session UID arrives so encryption and session match.
  encTokens = undefined;
  encServerResult = undefined;
  sessionUid = undefined;
  enable('btnSend', false);
  enable('btnDecrypt', false);
  enable('btnEncrypt', false);
  try {
    startEncryptWorker(clientKey);
    console.log('[Main] Sending server key to server...');
    $('keygenStatus').textContent = 'Keys generated, sending server key...';
    show('keygenSpin', true);
    sessionUid = await uploadServerKey(serverKey);
    console.log('[Main] Server key sent and UID received:', sessionUid);
    $('keygenStatus').textContent = 'Keys generated & UID received ✓';
    enable('btnEncrypt');
  } catch (error) {
    console.error('[Main] Server key submission error:', error);
    $('keygenStatus').textContent = `Server key submission failed: ${error.message}`;
    enable('btnEncrypt', false);
  } finally {
    show('keygenSpin', false);
  }
}

// Load the WASM module, then create the keygen worker. Initialization failures are
// shown in #keygenStatus and re-thrown (surfacing as an unhandled rejection in the console).
(async () => {
  try {
    console.log('[Main] Initializing WASM module...');
    await initWasm();
    console.log('[Main] WASM module initialized successfully');
    // Initialize the keygen worker
    keygenWorker = new Worker(new URL('./keygen-worker.js', import.meta.url), { type: 'module' });
    keygenWorker.onmessage = handleKeygenMessage;
    keygenWorker.onerror = (event) => {
      console.error('[Main] Keygen worker error:', event);
      $('keygenStatus').textContent = `Error generating keys: ${event.message ?? 'worker error'}`;
      show('keygenSpin', false);
    };
  } catch (e) {
    console.error('[Main] Failed to initialize WASM module:', e);
    $('keygenStatus').textContent = `Initialization Error: ${e.message}`;
    throw e;
  }
})();
// Start key generation in the keygen worker; the visible #keygenSpin doubles as the
// "keygen in progress" flag, so repeated clicks are ignored while it is shown.
$('btnKeygen').onclick = async () => {
  const spinnerVisible = $('keygenSpin').hidden === false;
  if (spinnerVisible) {
    console.log('[Main] Keygen already in progress, ignoring click');
    return;
  }

  const reportFailure = (err) => {
    console.error('[Main] Key generation error:', err);
    $('keygenStatus').textContent = `Error generating keys: ${err.message}`;
    show('keygenSpin', false);
  };

  show('keygenSpin', true);
  $('keygenStatus').textContent = 'generating…';
  try {
    // An empty payload is posted; this code sends the worker no parameters.
    keygenWorker.postMessage({});
  } catch (err) {
    reportFailure(err);
  }
};
/**
 * Tokenizes the input text and hands the token IDs to the encryption worker.
 * The ciphertext arrives asynchronously through the worker's 'success'/'error'
 * message. `llama3Tokenizer` is not imported in this module; it is presumably a
 * global provided by a separately loaded script — TODO confirm.
 */
$('btnEncrypt').onclick = async () => {
  const text = $('tokenInput').value.trim();
  if (!text) {
    console.error('[Main] No text provided for tokenization/encryption');
    alert('Please enter text to encrypt.');
    return;
  }
  if (!encryptWorker) {
    console.error('[Main] Encryption worker not initialized');
    alert('Encryption worker is not ready. Please generate keys first.');
    return;
  }
  show('encryptSpin', true);
  show('encIcon', false);
  enable('btnEncrypt', false);
  try {
    console.log('[Main] Tokenizing text:', text);
    const tokenIds = llama3Tokenizer.encode(text);
    if (tokenIds.length === 0) {
      throw new Error('Tokenizer produced no tokens');
    }
    // The previous ciphertext no longer matches the text (or the token count used for
    // the progress estimate); drop it so "Send" can't submit stale data meanwhile.
    encTokens = undefined;
    enable('btnSend', false);
    currentTokenCount = tokenIds.length; // Store token count
    console.log('[Main] Token IDs:', tokenIds);
    encryptWorker.postMessage({ type: 'encrypt', tokenIds });
  } catch (error) {
    console.error('[Main] Tokenization or encryption initiation error:', error);
    show('encryptSpin', false);
    enable('btnEncrypt', true);
    alert(`Error during tokenization/encryption: ${error.message}`);
  }
};
/**
 * Polls /get_task_status until the task reaches a terminal state.
 *
 * Resolves with the status payload (its `status` normalised to lower case) once the
 * task reports 'success' or 'completed', after painting the progress bar at 100%.
 * Resolves with null after a failure status, an HTTP error, or a network/parse
 * exception; in those cases the spinner and progress UI are torn down here. While
 * the task is still pending it waits `intervalMs` between requests.
 *
 * @param {string} currentTaskId - Task ID returned by /start_task.
 * @param {string} currentUid - Session UID returned by /add_key.
 * @param {number} [intervalMs=5000] - Delay between status requests.
 * @returns {Promise<object|null>}
 */
async function pollTaskStatus(currentTaskId, currentUid, intervalMs = 5000) {
  const successStates = ['success', 'completed'];
  const failureStates = ['failure', 'revoked', 'unknown', 'error'];
  // URLSearchParams encodes IDs safely instead of splicing them into the URL raw.
  const query = new URLSearchParams({ task_id: currentTaskId, uid: currentUid });

  // Shows `message` and tears down the busy UI; always yields null for the caller.
  const giveUp = (message) => {
    $('srvStatus').textContent = message;
    show('spin', false);
    stopProgressAnimation();
    hideProgress();
    return null;
  };

  try {
    for (;;) {
      const statusResponse = await fetch(`${SERVER}/get_task_status?${query}`);
      if (!statusResponse.ok) {
        const errorText = await statusResponse.text();
        console.error(`[Poll] Error fetching status: ${statusResponse.status} ${errorText}`);
        return giveUp(`Status check error: ${statusResponse.status}`);
      }
      const statusData = await statusResponse.json();
      console.log('[Poll] Task status:', statusData);
      const detailSuffix = statusData.details ? ` - ${statusData.details}` : '';
      $('srvStatus').textContent = `Status: ${statusData.status}${detailSuffix}`;

      // Compare case-insensitively for both outcomes; a missing status is treated as
      // 'unknown' (terminal) rather than polled forever.
      const status = String(statusData.status ?? 'unknown').toLowerCase();
      if (successStates.includes(status)) {
        updateProgress(100); // Complete the progress bar
        return { ...statusData, status };
      }
      if (failureStates.includes(status)) {
        console.error('[Poll] Task failed or unrecoverable:', statusData);
        return giveUp(`Task failed: ${statusData.status}`);
      }
      await new Promise((resolve) => setTimeout(resolve, intervalMs));
    }
  } catch (e) {
    console.error('[Poll] Polling exception:', e);
    return giveUp(`Polling error: ${e.message}`);
  }
}
/**
 * Downloads the encrypted task result into `encServerResult` and enables decryption.
 * HTTP and network failures are reported in #srvStatus rather than thrown; the
 * spinner and progress UI are torn down whether or not the download succeeded.
 *
 * @param {string} currentTaskId - Task ID returned by /start_task.
 * @param {string} currentUid - Session UID returned by /add_key.
 * @param {string} taskName - Server-side task name (this page passes 'synthid').
 * @returns {Promise<void>}
 */
async function getTaskResult(currentTaskId, currentUid, taskName) {
  // Seconds since submission (window.taskStartTime), or 'N/A' when it was never set.
  const elapsedSeconds = () => (window.taskStartTime
    ? ((performance.now() - window.taskStartTime) / 1000).toFixed(2)
    : 'N/A');

  $('srvStatus').textContent = 'Fetching result...';
  try {
    const query = new URLSearchParams({ task_name: taskName, task_id: currentTaskId, uid: currentUid });
    const resultResponse = await fetch(`${SERVER}/get_task_result?${query}`);
    if (!resultResponse.ok) {
      const errorText = await resultResponse.text();
      throw new Error(`Server /get_task_result error: ${resultResponse.status} ${errorText}`);
    }
    const resultArrayBuffer = await resultResponse.arrayBuffer();
    encServerResult = new Uint8Array(resultArrayBuffer);
    console.log(`[Main] Received encrypted result: ${encServerResult.length} bytes`);
    $('encResult').value = `(${encServerResult.length} B)`;
    $('srvStatus').textContent = `✓ result received (${elapsedSeconds()}s total)`;
    enable('btnDecrypt');
  } catch (e) {
    const duration = elapsedSeconds();
    console.error(`[Main] /get_task_result failed after ${duration}s:`, e);
    $('srvStatus').textContent = `Result fetch error: ${e.message} (${duration}s)`;
  } finally {
    show('spin', false);
    $('srvComputing').hidden = true;
    stopProgressAnimation();
    hideProgress();
  }
}
// Submit the encrypted tokens as a "synthid" task, then poll for its result.
$('btnSend').onclick = async () => {
  // The visible #spin doubles as the "task in flight" flag.
  if ($('spin').hidden === false) {
    console.log('[Main] Task submission/polling already in progress, ignoring click');
    return;
  }
  const haveInputs = sessionUid && encTokens;
  if (!haveInputs) {
    alert('Please generate keys and encrypt text first.');
    return;
  }

  show('encIcon', false);
  show('spin', true);
  $('srvStatus').textContent = 'Submitting task…';
  $('srvComputing').hidden = true;
  window.taskStartTime = performance.now();
  showProgress();
  startProgressAnimation(currentTokenCount);

  const secondsSinceStart = () => ((performance.now() - window.taskStartTime) / 1000).toFixed(2);

  try {
    const inputBlob = new Blob([encTokens], { type: 'application/octet-stream' });
    const body = new FormData();
    body.append('uid', sessionUid);
    body.append('task_name', 'synthid');
    body.append('encrypted_input', new File([inputBlob], "input.fheencrypted"));

    const startResponse = await fetch(`${SERVER}/start_task`, { method: 'POST', body });
    if (!startResponse.ok) {
      const errorText = await startResponse.text();
      throw new Error(`Server /start_task error: ${startResponse.status} ${errorText}`);
    }

    const { task_id: newTaskId } = await startResponse.json();
    taskId = newTaskId;
    console.log('[Main] Task submitted to server. Task ID:', taskId);
    $('srvStatus').textContent = `Task submitted (ID: ${taskId.substring(0,8)}...). Polling status...`;
    $('srvComputing').hidden = false;

    // Not awaited: the handler returns while polling continues in the background.
    pollTaskStatus(taskId, sessionUid).then((finalStatus) => {
      const succeeded = finalStatus
        && (finalStatus.status === 'success' || finalStatus.status === 'completed');
      if (succeeded) {
        getTaskResult(taskId, sessionUid, 'synthid');
      }
    });
  } catch (err) {
    const duration = secondsSinceStart();
    console.error(`[Main] Task submission failed after ${duration}s:`, err);
    $('srvStatus').textContent = `Task submission error: ${err.message} (${duration}s)`;
    show('spin', false);
    $('srvComputing').hidden = true;
    stopProgressAnimation();
    hideProgress();
  }
};
// Decrypt the server's encrypted result locally with the client key and display it.
$('btnDecrypt').onclick = () => {
  try {
    console.log('[Main] Starting decryption...');
    const values = Array.from(decrypt_serialized_u64_radix_flat_wasm(encServerResult, clientKey));
    const flag = values[0];
    const scoreScaled = values[1];
    const totalG = values[2];
    // The score is decoded as a fixed-point integer scaled by 1e6.
    const score = (Number(scoreScaled) / 1e6).toFixed(6);
    console.log('[Main] Decryption successful');
    console.log(`[Main] Result - flag: ${flag}, score: ${score}, total_g: ${totalG}`);
    $('decResult').textContent = `Flag: ${flag}, Score: ${score}, Total G: ${totalG}`;
  } catch (err) {
    console.error('[Main] Decryption error:', err);
    $('decResult').textContent = `Decryption failed: ${err.message}`;
  }
};