/**
 * Model Memory Calculator page (default export: `Calculator`).
 *
 * Estimates a model's memory footprint for four precisions (fp32, fp16, int8, int4)
 * and draws two d3 charts: a horizontal footprint bar and a batch-size versus
 * sequence-length line chart.
 *
 * chooseColor(precision)
 *   Returns the hex color assigned to the precision, or 'gray' when the lookup
 *   yields a falsy value.
 *
 * calculateTotalMemory(modelParams, hiddenSize, numLayers, intermediateSize, precision)
 *   Returns memoryPerInput + modelMemorySize + activationMemorySize, where
 *     memoryPerInput       = 4 * hiddenSize * numLayers / 1e9
 *     modelMemorySize      = modelParams * factor  (fp32 = 4, fp16 = 2, int8 = 1, int4 = 0.5)
 *     activationMemorySize = max(2 * intermediateSize, 4 * hiddenSize) / 1e9
 *   NOTE(review): modelParams is multiplied by the factor without the 1e9 division
 *   applied to the other two terms. Presumably modelParams is expressed in billions
 *   of parameters so the product is already in GB - confirm against callers.
 *
 * ModelSizeBarChart(props)
 *   Draws into the element held by chartRef at a fixed 600 x 50 size, using a linear
 *   scale from [0, largestModelSize] to [0, 600]. The effect only draws when
 *   modelSize > 0 and largestModelSize > 0; it first removes every child of the
 *   chart element.
 *   - modelSize + activationMemorySize > largestModelSize: a dashed outline with the
 *     text 'Out of Memory' (stroke and fill follow the light/dark theme).
 *   - otherwise: a filled bar for modelSize in the precision color, a second bar in
 *     '#a4b8e0' when activationMemorySize > 0, and, when deviceMemorySet is true, an
 *     outlined transparent bar covering the remaining capacity.
 *   activationMemorySize defaults to 0.
 *   NOTE(review): when the guard is false the effect returns without clearing, so a
 *   drawing from an earlier render stays in place - confirm this is intended.
 *
 * InferenceRuntimeLineChart({ availableMemory, memoryPerInput })
 *   Runs only when memoryPerInput > 0 and at least one availableMemory value is > 0.
 *   Builds a 600 x 400 chart (margins 20/20/50/50) with x domain [0, 4096] labelled
 *   'Sequence Length' and y domain [0, 128] labelled 'Batch Size', plus a legend
 *   listing FP32, FP16, INT8 and INT4.
 *   For every [precision, memory] entry of availableMemory it evaluates
 *     batchSize = memory / (seqLength * memoryPerInput)
 *   for seqLength in d3.range(1, 4096, 1), keeps the points with
 *   1 < batchSize <= 128 and seqLength > 1, and draws them as a curveBasis line in
 *   the precision color.
 *   Hovering a line shows the element selected by '#tooltip', filled with the
 *   pointer position converted through xScale.invert / yScale.invert and placed
 *   10px right of and below the page coordinates of the event.
 *   The zoom behavior (scale extent 0.5 to 10) is called on a `g` appended to the
 *   svg; its handler rescales both axes and writes the zoom transform to the
 *   `transform` attribute of every `path` under the svg.
 *   NOTE(review): that selection is not restricted to the data lines. Check whether
 *   other path elements under the svg (for example ones produced by the axis
 *   generators) are meant to receive the transform - confirm.
 *   NOTE(review): availableMemory is an object in the effect dependency list; if the
 *   parent builds a new object on each render the chart is presumably rebuilt on
 *   each render - verify against the caller.
 *
 * PrefillChunkingCalculator(props)
 *   Returns null when any of deviceMemory, modelParams, hiddenSize, numLayers or
 *   intermediateSize is falsy. Otherwise, for each precision, it shows
 *   calculateTotalMemory(...) formatted with toFixed(2) against deviceMemory in GB
 *   and passes max(2 * intermediateSize, 4 * hiddenSize) / 1e9 as
 *   activationMemorySize.
 *
 * StandardCalculator(props)
 *   Returns null when any of deviceMemory, modelParams, hiddenSize or numLayers is
 *   falsy. Local helpers:
 *   - calculateMemory(params, precision): params * bytes per parameter (4, 2, 1, 0.5).
 *   - calculateMemoryPerInput(hiddenSize, numLayers): 4 * hiddenSize * numLayers / 1e9.
 *   - calculateMaxInputSize(...): floor(availableMemory / (memoryPerInput * inputSize)),
 *     with availableMemory = deviceMemory - calculateMemory(modelParams, precision).
 *   - calculateMemoryValid(...): true when availableMemory is at least
 *     memoryPerInput * batchSize * seqLength.
 *
 * Calculator()
 *   Holds modelParams, hiddenSize, numLayers, intermediateSize and deviceMemory in
 *   state (all initially null), plus isPrefillChunking (initially false) and the
 *   modelSelectionTab / deviceSelectionTab flags (initially true). The manual inputs
 *   store Number(e.target.value). isPrefillChunking selects which of the two
 *   calculators is rendered and whether the intermediate-size input is shown.
 */
import React, { useState, useEffect, useRef } from 'react'; import * as d3 from 'd3'; import { useTheme } from '../context/themeContext'; import MODELS from '../utils/models'; import DEVICES from '../utils/devices'; type Precision = 'fp32' | 'fp16' | 'int8' | 'int4'; interface ModelSizeBarChartProps { modelSize: number; // in GB largestModelSize: number; // largest model in full precision (fp32) modelPrecision: Precision; deviceMemorySet: boolean; activationMemorySize?: number; // optional for standard calculator } interface InferenceRuntimeLineChartProps { availableMemory: AvailableMemory; // in GB memoryPerInput: number; // in GB } interface LineChartData { seqLength: number; batchSize: number; } interface AvailableMemory { int4: number; int8: number; fp16: number; fp32: number; } // Utility to determine color based on precision function chooseColor(precision: Precision) { const colors = { fp32: '#e45f5b', fp16: '#ffc068', int8: '#71cce9', int4: '#383d95', }; return colors[precision] || 'gray'; } // Utility function to calculate total memory with precision factor for prefill chunking function calculateTotalMemory( modelParams: number, hiddenSize: number, numLayers: number, intermediateSize: number, precision: Precision ) { const precisionFactor = { fp32: 4, fp16: 2, int8: 1, int4: 0.5, }; const memoryPerInput = (4 * hiddenSize * numLayers) / 1_000_000_000; // GB const modelMemorySize = modelParams * precisionFactor[precision]; // Adjusted by precision const activationMemorySize = Math.max(2 * intermediateSize, 4 * hiddenSize) / 1_000_000_000; // GB return memoryPerInput + modelMemorySize + activationMemorySize; } // Bar chart for model footprint (shared by both standard and prefill chunking calculators) function ModelSizeBarChart({ modelSize, largestModelSize, modelPrecision, deviceMemorySet, activationMemorySize = 0, // default to 0 for standard calculator }: ModelSizeBarChartProps) { const { theme } = useTheme(); const chartRef = useRef(null); const width = 
600; const height = 50; useEffect(() => { if (modelSize > 0 && largestModelSize > 0) { d3.select(chartRef.current).selectAll('*').remove(); const svg = d3.select(chartRef.current).attr('width', width).attr('height', height); const xScale = d3.scaleLinear().domain([0, largestModelSize]).range([0, width]); if (modelSize + activationMemorySize > largestModelSize) { svg .append('rect') .attr('x', 0) .attr('y', 0) .attr('width', width) .attr('height', height) .attr('fill', 'transparent') .style('stroke', theme === 'dark' ? '#f9fafb' : '#181f26') .style('stroke-dasharray', '4, 4') .style('stroke-width', '2px'); svg .append('text') .attr('x', width / 2) .attr('y', height / 2) .attr('text-anchor', 'middle') .attr('alignment-baseline', 'middle') .attr('fill', theme === 'dark' ? '#f9fafb' : '#181f26') .text('Out of Memory'); } else { svg .append('rect') .attr('x', 0) .attr('y', 0) .attr('width', xScale(modelSize)) .attr('height', height) .attr('fill', chooseColor(modelPrecision)); if (activationMemorySize > 0) { svg .append('rect') .attr('x', xScale(modelSize)) .attr('y', 0) .attr('width', xScale(activationMemorySize)) .attr('height', height) .attr('fill', '#a4b8e0'); } if (deviceMemorySet) { svg .append('rect') .attr('x', xScale(modelSize + activationMemorySize)) .attr('y', 0) .attr('width', xScale(largestModelSize - (modelSize + activationMemorySize))) .attr('height', height) .attr('fill', 'transparent') .style('stroke', chooseColor(modelPrecision)) .style('stroke-width', '2px'); } } } }, [modelSize, largestModelSize, modelPrecision, deviceMemorySet, activationMemorySize, theme]); return ; } // Line chart for inference runtime (shared by both standard and prefill chunking calculators) function InferenceRuntimeLineChart({ availableMemory, memoryPerInput, }: InferenceRuntimeLineChartProps) { const { theme } = useTheme(); const chartRef = useRef(null); const maxSeqLength = 4096; const maxBatchSize = 128; useEffect(() => { if (memoryPerInput > 0 && 
Object.values(availableMemory).some((val) => val > 0)) { const margin = { top: 20, right: 20, bottom: 50, left: 50 }; const width = 600 - margin.left - margin.right; const height = 400 - margin.top - margin.bottom; const precisions = [ { name: 'FP32', color: '#e45f5b' }, { name: 'FP16', color: '#ffc068' }, { name: 'INT8', color: '#71cce9' }, { name: 'INT4', color: '#383d95' }, ]; const svg = d3.select(chartRef.current); svg.selectAll('*').remove(); const xScale = d3.scaleLinear().domain([0, maxSeqLength]).range([0, width]); const yScale = d3.scaleLinear().domain([0, maxBatchSize]).range([height, 0]); const xAxis = d3.axisBottom(xScale); const yAxis = d3.axisLeft(yScale); const zoom = d3 .zoom() .scaleExtent([0.5, 10]) .translateExtent([ [-width, -height], [2 * width, 2 * height], ]) .on('zoom', (event) => { const transform = event.transform; svg.select('.x-axis').call(xAxis.scale(transform.rescaleX(xScale))); svg.select('.y-axis').call(yAxis.scale(transform.rescaleY(yScale))); svg.selectAll('path').attr('transform', transform); }); svg .attr('width', width + margin.left + margin.right) .attr('height', height + margin.top + margin.bottom) .append('g') .attr('transform', `translate(${margin.left}, ${margin.top})`) .call(zoom); svg .append('g') .attr('class', 'x-axis') .attr('transform', `translate(${margin.left}, ${height + margin.top})`) .call(xAxis); svg .append('g') .attr('class', 'y-axis') .attr('transform', `translate(${margin.left}, ${margin.top})`) .call(yAxis); svg .append('text') .attr('transform', `translate(${width / 2 + margin.left}, ${height + margin.top + 40})`) .style('text-anchor', 'middle') .attr('fill', theme === 'dark' ? '#f9fafb' : '#181f26') .text('Sequence Length'); svg .append('text') .attr('transform', `rotate(-90)`) .attr('y', 0) .attr('x', 0 - height / 2 - margin.top) .attr('dy', '1em') .style('text-anchor', 'middle') .attr('fill', theme === 'dark' ? 
'#f9fafb' : '#181f26') .text('Batch Size'); // Adding legend for precisions const legend = svg .append('g') .attr('class', 'legend') .attr('transform', `translate(${width - 20}, 20)`); precisions.forEach((precision, index) => { const legendItem = legend.append('g').attr('transform', `translate(0, ${index * 30})`); legendItem .append('rect') .attr('x', 10) .attr('y', 10) .attr('width', 10) .attr('height', 10) .style('fill', precision.color); legendItem .append('text') .attr('x', 30) .attr('y', 16) .text(precision.name) .style('font-size', '16px') .style('fill', theme === 'dark' ? '#f9fafb' : '#181f26') .attr('alignment-baseline', 'middle'); }); legend .append('rect') .attr('class', 'legend-box') .attr('width', 80) .attr('height', precisions.length * 30) .style('fill', 'none') .style('stroke-width', '1px') .style('stroke', theme === 'dark' ? '#f9fafb' : '#181f26'); const tooltip = d3.select('#tooltip'); for (const [precision, memory] of Object.entries(availableMemory)) { const sequenceLengths = d3 .range(1, maxSeqLength, 1) .map((seqLength) => ({ seqLength, batchSize: memory / (seqLength * memoryPerInput), })) .filter((d) => d.batchSize <= maxBatchSize && d.batchSize > 1 && d.seqLength > 1); const lineGroup = svg .append('g') .attr('transform', `translate(${margin.left}, ${margin.top})`); const line = d3 .line() .x((d) => xScale(d.seqLength)) .y((d) => yScale(d.batchSize)) .curve(d3.curveBasis); lineGroup .append('path') .datum(sequenceLengths) .attr('fill', 'none') .attr('stroke', chooseColor(precision as Precision)) .attr('stroke-width', 4) .attr('d', line) .on('mouseover', () => { tooltip.style('opacity', 1); tooltip.style('background-color', theme === 'dark' ? '#181f26' : '#f9fafb'); }) .on('mousemove', (event) => { tooltip.selectAll('text').remove(); const [x, y] = d3.pointer(event); const xValue = xScale.invert(x); const yValue = yScale.invert(y); tooltip .html(`Sequence Length: ${xValue.toFixed(0)}
Batch Size: ${yValue.toFixed(0)}`) .attr('fill', theme === 'dark' ? '#f9fafb' : '#181f26') .style('left', event.pageX + 10 + 'px') .style('top', event.pageY + 10 + 'px'); }) .on('mouseout', () => { tooltip.style('opacity', 0); }); } } }, [availableMemory, memoryPerInput, theme]); return ( <>
); } // Prefill Chunking Calculator with Updated Logic and Precision Adjustment function PrefillChunkingCalculator({ deviceMemory, modelParams, hiddenSize, numLayers, intermediateSize, }: { deviceMemory: number; modelParams: number; hiddenSize: number; numLayers: number; intermediateSize: number; }) { if (!deviceMemory || !modelParams || !hiddenSize || !numLayers || !intermediateSize) { return null; } return ( <> {/* Model Footprint with Prefill Chunking */}
Model Footprint with Prefill Chunking
{(['fp32', 'fp16', 'int8', 'int4'] as Precision[]).map((precision) => { const totalMemory = calculateTotalMemory( modelParams, hiddenSize, numLayers, intermediateSize, precision ); return (
{precision.toUpperCase()}
0} activationMemorySize={ Math.max(2 * intermediateSize, 4 * hiddenSize) / 1_000_000_000 } />
{totalMemory.toFixed(2)} / {deviceMemory} GB
); })}
{/* Inference Runtime with Prefill Chunking */}
Maximum Batch Size / Sequence Length with Prefill Chunking
); } // Standard Model Memory Calculator (unchanged) function StandardCalculator({ deviceMemory, modelParams, hiddenSize, numLayers, }: { deviceMemory: number; modelParams: number; hiddenSize: number; numLayers: number; }) { if (!deviceMemory || !modelParams || !hiddenSize || !numLayers) { return null; } function calculateMemory(params: number, precision: Precision) { const paramSize = { fp32: 4, fp16: 2, int8: 1, int4: 0.5 }; return params * paramSize[precision]; // in GB } function calculateMemoryPerInput(hiddenSize: number, numLayers: number) { const memoryPerInput = 4 * hiddenSize * numLayers; return memoryPerInput / 1_000_000_000; // in GB } function calculateMaxInputSize( deviceMemory: number, modelParams: number, hiddenSize: number, numLayers: number, precision: Precision, inputSize: number, ) { const memoryPerInput = calculateMemoryPerInput(hiddenSize, numLayers); const availableMemory = deviceMemory - calculateMemory(modelParams, precision); return Math.floor(availableMemory / (memoryPerInput * inputSize)); } function calculateMemoryValid( deviceMemory: number, modelParams: number, hiddenSize: number, numLayers: number, precision: Precision, batchSize: number, seqLength: number, ) { const memoryPerInput = calculateMemoryPerInput(hiddenSize, numLayers); const availableMemory = deviceMemory - calculateMemory(modelParams, precision); return availableMemory >= memoryPerInput * batchSize * seqLength; } return ( <> {/* Model Footprint */}
Model Footprint
{(['fp32', 'fp16', 'int8', 'int4'] as Precision[]).map((precision) => (
{precision.toUpperCase()}
0} />
{calculateMemory(modelParams, precision).toFixed(2)} / {deviceMemory} GB
))}
{/* Maximum Batch Size / Sequence Length */}
Maximum Batch Size / Sequence Length
); } // Main Calculator Page const Calculator = () => { const [modelParams, setModelParams] = useState(null); const [hiddenSize, setHiddenSize] = useState(null); const [numLayers, setNumLayers] = useState(null); const [intermediateSize, setIntermediateSize] = useState(null); const [deviceMemory, setDeviceMemory] = useState(null); const [isPrefillChunking, setIsPrefillChunking] = useState(false); const [modelSelectionTab, setModelSelectionTab] = useState(true); const [deviceSelectionTab, setDeviceSelectionTab] = useState(true); return (
{/* Toggle Between Standard and Prefill Chunking */}
{/* Model and Device Selection */}
Model Memory Calculator
Use our Model Memory Calculator to help you estimate the memory footprint of your model and the maximum batch size/sequence length combination you can run on your device.
{/* Model Selection */}
Model
{modelSelectionTab ? ( <> ) : ( <> setModelParams(Number(e.target.value))} /> setHiddenSize(Number(e.target.value))} /> setNumLayers(Number(e.target.value))} /> {isPrefillChunking && ( <> setIntermediateSize(Number(e.target.value))} /> )} )}
{/* Device Selection */}
Device
{deviceSelectionTab ? ( <> ) : ( <> setDeviceMemory(Number(e.target.value))} /> )}
{/* Render Appropriate Calculator Based on Toggle */} {isPrefillChunking ? ( ) : ( )}
); }; export default Calculator;