// NOTE(review): this file appears to have been mangled during extraction — the
// original newlines were collapsed and every angle-bracketed span (JSX markup
// and TypeScript generic arguments) was stripped. Tokens below are preserved
// verbatim, reflowed for readability; the missing JSX must be restored from
// version control. — TODO confirm against the repository history.
import React, { useState, useEffect, useRef } from 'react'
import * as d3 from 'd3'
import { useTheme } from '../context/themeContext'
import MODELS from '../utils/models'
import DEVICES from '../utils/devices'

// Weight precisions supported by the calculators; the byte-per-parameter
// mapping for each lives in the calculateMemory helpers below.
type Precision = 'fp32' | 'fp16' | 'int8' | 'int4'

// Props for the standard model-footprint bar chart.
interface ModelSizeBarChartProps {
  modelSize: number // in GB
  largestModelSize: number // largest model in full precision (fp32)
  modelPrecision: Precision // enum of fp32, fp16, int8, int4
  deviceMemorySet: boolean // true if device memory is set
}

// Props for the standard inference-runtime line chart.
interface InferenceRuntimeLineChartProps {
  availableMemory: AvailableMemory // in GB
  memoryPerInput: number // in GB
}

// One point on a memory-budget curve: a (sequence length, batch size) pair.
interface LineChartData {
  seqLength: number
  batchSize: number
}

// Remaining device memory (GB) after loading the model at each precision.
interface AvailableMemory {
  int4: number
  int8: number
  fp16: number
  fp32: number
}

// Table containing the mapping of backends to precisions
const BackendPrecisionTable = () => {
  // NOTE(review): the table JSX was stripped at extraction; only the cell text
  // below survived. Restore the markup from version control.
  return (
    Backend GPU CPU Accuracy
    fast 16 16 ⭐⭐⭐
    compress-fast 4 8 ⭐⭐
    compress 4 4 ⭐
    baseline 16 16 ⭐⭐⭐
  )
}

// Bar chart for model footprint (standard version)
function ModelSizeBarChart({
  modelSize,
  largestModelSize,
  modelPrecision,
  deviceMemorySet,
}: ModelSizeBarChartProps) {
  const { theme } = useTheme()
  // Ref to the rendered svg element (generic argument stripped — TODO confirm).
  const chartRef = useRef(null)
  const width = 600
  const height = 50

  useEffect(() => {
    // Clear the previous render; this effect re-runs whenever a prop changes.
    d3.select(chartRef.current).selectAll('*').remove()
    const svg = d3.select(chartRef.current).attr('width', width).attr('height', height)
    // Scale GB onto pixels, relative to the largest (fp32) footprint.
    const xScale = d3.scaleLinear().domain([0, largestModelSize]).range([0, width])
    if (modelSize > largestModelSize) {
      // Model does not fit on the device: dashed outline + "Out of Memory" label.
      svg
        .append('rect')
        .attr('x', 0)
        .attr('y', 0)
        .attr('width', width)
        .attr('height', height)
        .attr('fill', 'transparent')
        .style('stroke', theme === 'dark' ? '#f9fafb' : '#181f26')
        .style('stroke-dasharray', '4, 4')
        .style('stroke-width', '2px')
      svg
        .append('text')
        .attr('x', width / 2)
        .attr('y', height / 2)
        .attr('text-anchor', 'middle')
        .attr('alignment-baseline', 'middle')
        .attr('fill', theme === 'dark' ? '#f9fafb' : '#181f26')
        .text('Out of Memory')
    } else {
      // Solid bar for the model footprint, colored by precision.
      svg
        .append('rect')
        .attr('x', 0)
        .attr('y', 0)
        .attr('width', xScale(modelSize))
        .attr('height', height)
        .attr('fill', chooseColor(modelPrecision))
      if (deviceMemorySet) {
        // Outline the unused remainder of device memory beside the filled bar.
        svg
          .append('rect')
          .attr('x', xScale(modelSize))
          .attr('y', 0)
          .attr('width', xScale(largestModelSize - modelSize))
          .attr('height', height)
          .attr('fill', 'transparent')
          .style('stroke', chooseColor(modelPrecision))
          .style('stroke-width', '2px')
      }
    }
  }, [modelSize, largestModelSize, modelPrecision, deviceMemorySet, theme])

  // Map a precision to its chart color; 'gray' is an unreachable fallback
  // given the Precision union, kept for safety.
  function chooseColor(precision: Precision) {
    const colors = {
      fp32: '#e45f5b',
      fp16: '#ffc068',
      int8: '#71cce9',
      int4: '#383d95',
    }
    return colors[precision] || 'gray'
  }

  // NOTE(review): JSX stripped — presumably returned the svg bound to chartRef.
  return
}

// Line chart for standard inference runtime
function InferenceRuntimeLineChart({
  availableMemory,
  memoryPerInput,
}: InferenceRuntimeLineChartProps) {
  const { theme } = useTheme()
  // Ref to the rendered svg element (generic argument stripped — TODO confirm).
  const chartRef = useRef(null)
  // Axis bounds for the plotted curves.
  const maxSeqLength = 4096
  const maxBatchSize = 128

  useEffect(() => {
    const margin = { top: 20, right:
20, bottom: 50, left: 50 }
    // Inner plot area after margins.
    const width = 600 - margin.left - margin.right
    const height = 400 - margin.top - margin.bottom
    // Legend entries; colors match chooseColor's palette below.
    const precisions = [
      { name: 'FP32', color: '#e45f5b' },
      { name: 'FP16', color: '#ffc068' },
      { name: 'INT8', color: '#71cce9' },
      { name: 'INT4', color: '#383d95' },
    ]
    const svg = d3.select(chartRef.current)
    // Clear the previous render; this effect re-runs whenever a prop changes.
    svg.selectAll('*').remove()
    // x: sequence length, y: batch size (range inverted so larger batch plots higher).
    const xScale = d3.scaleLinear().domain([0, maxSeqLength]).range([0, width])
    const yScale = d3.scaleLinear().domain([0, maxBatchSize]).range([height, 0])
    const xAxis = d3.axisBottom(xScale)
    const yAxis = d3.axisLeft(yScale)
    svg
      .attr('width', width + margin.left + margin.right)
      .attr('height', height + margin.top + margin.bottom)
      .append('g')
      .attr('transform', `translate(${margin.left}, ${margin.top})`)
    svg
      .append('g')
      .attr('transform', `translate(${margin.left}, ${height + margin.top})`)
      .call(xAxis)
    svg.append('g').attr('transform', `translate(${margin.left}, ${margin.top})`).call(yAxis)
    // Axis labels, themed for dark/light mode.
    svg
      .append('text')
      .attr('transform', `translate(${width / 2 + margin.left}, ${height + margin.top + 40})`)
      .style('text-anchor', 'middle')
      .attr('fill', theme === 'dark' ? '#f9fafb' : '#181f26')
      .text('Sequence Length')
    svg
      .append('text')
      .attr('transform', `rotate(-90)`)
      .attr('y', 0)
      .attr('x', 0 - height / 2 - margin.top)
      .attr('dy', '1em')
      .style('text-anchor', 'middle')
      .attr('fill', theme === 'dark' ? '#f9fafb' : '#181f26')
      .text('Batch Size')
    // Legend: one swatch + label per precision, boxed.
    const legend = svg
      .append('g')
      .attr('class', 'legend')
      .attr('transform', `translate(${width - 20}, 20)`)
    precisions.forEach((precision, index) => {
      const legendItem = legend.append('g').attr('transform', `translate(0, ${index * 30})`)
      legendItem
        .append('rect')
        .attr('x', 10)
        .attr('y', 10)
        .attr('width', 10)
        .attr('height', 10)
        .style('fill', precision.color)
      legendItem
        .append('text')
        .attr('x', 30)
        .attr('y', 16)
        .text(precision.name)
        .style('font-size', '16px')
        .style('fill', theme === 'dark' ? '#f9fafb' : '#181f26')
        .attr('alignment-baseline', 'middle')
    })
    legend
      .append('rect')
      .attr('class', 'legend-box')
      .attr('width', 80)
      .attr('height', precisions.length * 30)
      .style('fill', 'none')
      .style('stroke-width', '1px')
      .style('stroke', theme === 'dark' ? '#f9fafb' : '#181f26')
    // Shared tooltip element selected by id; presumably rendered elsewhere on
    // the page — verify against the page markup.
    const tooltip = d3.select('#tooltip')
    // One memory-budget curve per precision:
    //   batchSize = availableMemory / (seqLength * memoryPerInput)
    for (const [precision, memory] of Object.entries(availableMemory)) {
      const sequenceLengths = d3
        .range(1, maxSeqLength, 1)
        .map((seqLength) => {
          return { seqLength, batchSize: memory / (seqLength * memoryPerInput) }
        })
        // Clip the curve to the plotted axis ranges.
        .filter((seqLength) => seqLength.batchSize <= maxBatchSize)
        .filter((seqLength) => seqLength.batchSize > 1 && seqLength.seqLength > 1)
      const lineGroup = svg
        .append('g')
        .attr('transform', `translate(${margin.left}, ${margin.top})`)
      // NOTE(review): d3.line's generic argument (likely LineChartData) was
      // stripped with the other angle-bracketed spans — TODO confirm.
      const line = d3
        .line()
        .x((d) => xScale(d.seqLength))
        .y((d) => yScale(d.batchSize))
        .curve(d3.curveBasis)
      lineGroup
        .append('path')
        .datum(sequenceLengths)
        .attr('fill', 'none')
        .attr('stroke', chooseColor(precision as Precision))
        .attr('stroke-width', 4)
        .attr('d', line)
        .on('mouseover', () => {
          tooltip.style('opacity', 1)
          tooltip.style('background-color', theme === 'dark' ? '#181f26' : '#f9fafb')
        })
        .on('mousemove', (event) => {
          tooltip.selectAll('text').remove()
          // Convert pointer position back to data coordinates for display.
          const [x, y] = d3.pointer(event)
          const xValue = xScale.invert(x)
          const yValue = yScale.invert(y)
          tooltip
            .html(`Sequence Length: ${xValue.toFixed(0)}
Batch Size: ${yValue.toFixed(0)}`)
            .attr('fill', theme === 'dark' ? '#f9fafb' : '#181f26')
            .style('left', event.pageX + 10 + 'px')
            .style('top', event.pageY + 10 + 'px')
        })
        .on('mouseout', () => {
          tooltip.style('opacity', 0)
        })
    }
  }, [availableMemory, memoryPerInput, theme])

  // Map a precision to its chart color; 'gray' is an unreachable fallback
  // given the Precision union, kept for safety.
  function chooseColor(precision: Precision) {
    const colors = {
      fp32: '#e45f5b',
      fp16: '#ffc068',
      int8: '#71cce9',
      int4: '#383d95',
    }
    return colors[precision] || 'gray'
  }

  // NOTE(review): JSX stripped — only the opening fragment token survived.
  return (
    <>
  )
}

// Prefill Chunking Model Size Bar Chart
// Same as ModelSizeBarChart but adds a second segment for activation memory.
function PrefillChunkingModelSizeBarChart({
  modelSize,
  largestModelSize,
  modelPrecision,
  deviceMemorySet,
  activationMemorySize,
}: {
  modelSize: number // in GB
  largestModelSize: number // largest model in full precision (fp32)
  modelPrecision: Precision // enum of fp32, fp16, int8, int4
  deviceMemorySet: boolean // true if device memory is set
  activationMemorySize: number // additional memory for activations in GB
}) {
  const { theme } = useTheme()
  // Ref to the rendered svg element (generic argument stripped — TODO confirm).
  const chartRef = useRef(null)
  const width = 600
  const height = 50

  useEffect(() => {
    // Clear the previous render; this effect re-runs whenever a prop changes.
    d3.select(chartRef.current).selectAll('*').remove()
    const svg = d3.select(chartRef.current).attr('width', width).attr('height', height)
    const xScale = d3.scaleLinear().domain([0, largestModelSize]).range([0, width])
    // Out-of-memory test includes activation memory on top of the weights.
    if (modelSize + activationMemorySize > largestModelSize) {
      svg
        .append('rect')
        .attr('x', 0)
        .attr('y', 0)
        .attr('width', width)
        .attr('height', height)
        .attr('fill', 'transparent')
        .style('stroke', theme === 'dark' ? '#f9fafb' : '#181f26')
        .style('stroke-dasharray', '4, 4')
        .style('stroke-width', '2px')
      svg
        .append('text')
        .attr('x', width / 2)
        .attr('y', height / 2)
        .attr('text-anchor', 'middle')
        .attr('alignment-baseline', 'middle')
        .attr('fill', theme === 'dark' ? '#f9fafb' : '#181f26')
        .text('Out of Memory')
    } else {
      // Segment 1: model weights, colored by precision.
      svg
        .append('rect')
        .attr('x', 0)
        .attr('y', 0)
        .attr('width', xScale(modelSize))
        .attr('height', height)
        .attr('fill', chooseColor(modelPrecision))
      // Segment 2: activation memory, fixed light-blue fill.
      svg
        .append('rect')
        .attr('x', xScale(modelSize))
        .attr('y', 0)
        .attr('width', xScale(activationMemorySize))
        .attr('height', height)
        .attr('fill', '#a4b8e0')
      if (deviceMemorySet) {
        // Outline the unused remainder of device memory.
        svg
          .append('rect')
          .attr('x', xScale(modelSize + activationMemorySize))
          .attr('y', 0)
          .attr('width', xScale(largestModelSize - (modelSize + activationMemorySize)))
          .attr('height', height)
          .attr('fill', 'transparent')
          .style('stroke', chooseColor(modelPrecision))
          .style('stroke-width', '2px')
      }
    }
  }, [modelSize, largestModelSize, modelPrecision, deviceMemorySet, activationMemorySize, theme])

  // Map a precision to its chart color; 'gray' is an unreachable fallback
  // given the Precision union, kept for safety.
  function chooseColor(precision: Precision) {
    const colors = {
      fp32: '#e45f5b',
      fp16: '#ffc068',
      int8: '#71cce9',
      int4: '#383d95',
    }
    return colors[precision] || 'gray'
  }

  // NOTE(review): JSX stripped — presumably returned the svg bound to chartRef.
  return
}

// Prefill Chunking Inference Runtime Line Chart
// Same as InferenceRuntimeLineChart but subtracts activation memory from the
// per-precision budget before computing each curve.
function PrefillChunkingInferenceRuntimeLineChart({
  availableMemory,
  memoryPerInput,
  activationMemorySize,
}: {
  availableMemory: AvailableMemory // in GB
  memoryPerInput: number // in GB
  activationMemorySize: number // in GB
}) {
  const { theme } = useTheme()
  // Ref to the rendered svg element (generic argument stripped — TODO confirm).
  const chartRef = useRef(null)
  // Axis bounds for the plotted curves.
  const maxSeqLength = 4096
  const maxBatchSize = 128

  useEffect(() => {
    const margin = { top: 20, right: 20, bottom: 50, left: 50 }
    // Inner plot area after margins.
    const width = 600 - margin.left - margin.right
    const height = 400 - margin.top - margin.bottom
    // Legend entries; colors match chooseColor's palette below.
    const precisions = [
      { name: 'FP32', color: '#e45f5b' },
      { name: 'FP16', color: '#ffc068' },
      { name: 'INT8', color: '#71cce9' },
      { name: 'INT4', color: '#383d95' },
    ]
    const svg = d3.select(chartRef.current)
    svg.selectAll('*').remove()
    // x: sequence length, y: batch size (range inverted so larger batch plots higher).
    const xScale = d3.scaleLinear().domain([0, maxSeqLength]).range([0, width])
    const yScale = d3.scaleLinear().domain([0, maxBatchSize]).range([height, 0])
    const xAxis = d3.axisBottom(xScale)
    const yAxis = d3.axisLeft(yScale)
    svg
      .attr('width', width + margin.left + margin.right)
      .attr('height', height + margin.top + margin.bottom)
      .append('g')
      .attr('transform', `translate(${margin.left}, ${margin.top})`)
    svg
      .append('g')
      .attr('transform', `translate(${margin.left}, ${height + margin.top})`)
      .call(xAxis)
    svg.append('g').attr('transform', `translate(${margin.left}, ${margin.top})`).call(yAxis)
    // Axis labels, themed for dark/light mode.
    svg
      .append('text')
      .attr('transform', `translate(${width / 2 + margin.left}, ${height + margin.top + 40})`)
      .style('text-anchor', 'middle')
      .attr('fill', theme === 'dark' ? '#f9fafb' : '#181f26')
      .text('Sequence Length')
    svg
      .append('text')
      .attr('transform', `rotate(-90)`)
      .attr('y', 0)
      .attr('x', 0 - height / 2 - margin.top)
      .attr('dy', '1em')
      .style('text-anchor', 'middle')
      .attr('fill', theme === 'dark' ? '#f9fafb' : '#181f26')
      .text('Batch Size')
    // Legend: one swatch + label per precision, boxed.
    const legend = svg
      .append('g')
      .attr('class', 'legend')
      .attr('transform', `translate(${width - 20}, 20)`)
    precisions.forEach((precision, index) => {
      const legendItem = legend.append('g').attr('transform', `translate(0, ${index * 30})`)
      legendItem
        .append('rect')
        .attr('x', 10)
        .attr('y', 10)
        .attr('width', 10)
        .attr('height', 10)
        .style('fill', precision.color)
      legendItem
        .append('text')
        .attr('x', 30)
        .attr('y', 16)
        .text(precision.name)
        .style('font-size', '16px')
        .style('fill', theme === 'dark' ? '#f9fafb' : '#181f26')
        .attr('alignment-baseline', 'middle')
    })
    legend
      .append('rect')
      .attr('class', 'legend-box')
      .attr('width', 80)
      .attr('height', precisions.length * 30)
      .style('fill', 'none')
      .style('stroke-width', '1px')
      .style('stroke', theme === 'dark' ? '#f9fafb' : '#181f26')
    // Shared tooltip element selected by id; presumably rendered elsewhere on
    // the page — verify against the page markup.
    const tooltip = d3.select('#tooltip')
    // One curve per precision; activation memory is reserved up front:
    //   batchSize = (availableMemory - activationMemorySize) / (seqLength * memoryPerInput)
    for (const [precision, memory] of Object.entries(availableMemory)) {
      const sequenceLengths = d3
        .range(1, maxSeqLength, 1)
        .map((seqLength) => {
          return {
            seqLength,
            batchSize: (memory - activationMemorySize) / (seqLength * memoryPerInput),
          }
        })
        // Clip the curve to the plotted axis ranges.
        .filter((seqLength) => seqLength.batchSize <= maxBatchSize)
        .filter((seqLength) => seqLength.batchSize > 1 && seqLength.seqLength > 1)
      const lineGroup = svg
        .append('g')
        .attr('transform', `translate(${margin.left}, ${margin.top})`)
      // NOTE(review): d3.line's generic argument (likely LineChartData) was
      // stripped with the other angle-bracketed spans — TODO confirm.
      const line = d3
        .line()
        .x((d) => xScale(d.seqLength))
        .y((d) => yScale(d.batchSize))
        .curve(d3.curveBasis)
      lineGroup
        .append('path')
        .datum(sequenceLengths)
        .attr('fill', 'none')
        .attr('stroke', chooseColor(precision as Precision))
        .attr('stroke-width', 4)
        .attr('d', line)
        .on('mouseover', () => {
          tooltip.style('opacity', 1)
          tooltip.style('background-color', theme === 'dark' ? '#181f26' : '#f9fafb')
        })
        .on('mousemove', (event) => {
          tooltip.selectAll('text').remove()
          // Convert pointer position back to data coordinates for display.
          const [x, y] = d3.pointer(event)
          const xValue = xScale.invert(x)
          const yValue = yScale.invert(y)
          tooltip
            .html(`Sequence Length: ${xValue.toFixed(0)}
Batch Size: ${yValue.toFixed(0)}`)
            .attr('fill', theme === 'dark' ? '#f9fafb' : '#181f26')
            .style('left', event.pageX + 10 + 'px')
            .style('top', event.pageY + 10 + 'px')
        })
        .on('mouseout', () => {
          tooltip.style('opacity', 0)
        })
    }
  }, [availableMemory, memoryPerInput, activationMemorySize, theme])

  // Map a precision to its chart color; 'gray' is an unreachable fallback
  // given the Precision union, kept for safety.
  function chooseColor(precision: Precision) {
    const colors = {
      fp32: '#e45f5b',
      fp16: '#ffc068',
      int8: '#71cce9',
      int4: '#383d95',
    }
    return colors[precision] || 'gray'
  }

  // NOTE(review): JSX stripped — only the opening fragment token survived.
  return (
    <>
  )
}

// Prefill Chunking Calculator
// Renders the prefill-chunking variant of the footprint/runtime charts.
// batchSize and seqLength are accepted but not used in the surviving code;
// presumably they were passed into the stripped JSX — verify on restore.
const PrefillChunkingCalculator = ({
  deviceMemory,
  modelParams,
  hiddenSize,
  numLayers,
  batchSize,
  seqLength,
  maxChunkSize,
  intermediateSize,
}: {
  deviceMemory: number
  modelParams: number
  hiddenSize: number
  numLayers: number
  batchSize: number | null
  seqLength: number | null
  maxChunkSize: number | null
  intermediateSize: number | null
}) => {
  // Function to calculate memory usage based on precision and model parameters
  // (params is in billions, so bytes-per-param yields GB directly).
  function calculateMemory(params: number, precision: 'fp32' | 'fp16' | 'int8' | 'int4'): number {
    const paramSize = { fp32: 4, fp16: 2, int8: 1, int4: 0.5 }
    return params * paramSize[precision] // in GB
  }

  // KV-cache cost per (token × batch) slot, derived from model dimensions.
  function calculateMemoryPerInput(hiddenSize: number, numLayers: number) {
    const memoryPerInput = 4 * hiddenSize * numLayers
    return memoryPerInput / 1_000_000_000 // in GB
  }

  // Peak activation memory for one prefill chunk. The formula takes the wider
  // of the MLP (2 * intermediateSize) and attention (4 * hiddenSize) widths —
  // assumed fp16 activations (factor 2) — TODO confirm the derivation.
  function calculateMemoryNeededForActivations(
    maxChunkSize: number,
    intermediateSize: number,
    hiddenSize: number,
  ) {
    const activationMemorySize = maxChunkSize * 2 * Math.max(2 * intermediateSize, 4 * hiddenSize)
    return activationMemorySize / 1_000_000_000 // in GB
  }

  const memoryPerInput = calculateMemoryPerInput(hiddenSize, numLayers)
  // Non-null assertions: the parent only renders this calculator once both
  // prefill inputs are set — verify against the caller.
  const activationMemorySize = calculateMemoryNeededForActivations(
    maxChunkSize!,
    intermediateSize!,
    hiddenSize,
  )

  // NOTE(review): JSX stripped — only text children and expression fragments
  // survive below (e.g. the chart elements lost their opening tags). Restore
  // the markup from version control.
  return (
    Prefill Chunking Calculator
    Model Footprint with Prefill Chunking
    FP32
    0} activationMemorySize={activationMemorySize} />
    {(calculateMemory(modelParams, 'fp32') + activationMemorySize).toFixed(2)}{' '} {deviceMemory !== null && deviceMemory > 0 ? `/ ${deviceMemory} ` : null}GB
    FP16
    0} activationMemorySize={activationMemorySize} />
    {(calculateMemory(modelParams, 'fp16') + activationMemorySize).toFixed(2)}{' '} {deviceMemory !== null && deviceMemory > 0 ? `/ ${deviceMemory} ` : null}GB
    INT8
    0} activationMemorySize={activationMemorySize} />
    {(calculateMemory(modelParams, 'int8') + activationMemorySize).toFixed(2)}{' '} {deviceMemory !== null && deviceMemory > 0 ? `/ ${deviceMemory} ` : null}GB
    INT4
    0} activationMemorySize={activationMemorySize} />
    {(calculateMemory(modelParams, 'int4') + activationMemorySize).toFixed(2)}{' '} {deviceMemory !== null && deviceMemory > 0 ? `/ ${deviceMemory} ` : null}GB
    Maximum Batch Size / Sequence Length with Prefill Chunking
  )
}

// Calculator page with toggle feature
// Top-level page: collects model/device inputs and renders either the standard
// calculator or the prefill-chunking calculator.
// NOTE(review): the useState generic arguments (e.g. useState<number | null>)
// were stripped with the other angle-bracketed spans — TODO confirm.
const Calculator = () => {
  // Model Parameters (billions)
  const [modelParams, setModelParams] = useState(null)
  const [hiddenSize, setHiddenSize] = useState(null)
  const [numLayers, setNumLayers] = useState(null)
  // Device Memory (GB)
  const [deviceMemory, setDeviceMemory] = useState(null)
  // Inputs
  const [batchSize, setBatchSize] = useState(null)
  const [seqLength, setSeqLength] = useState(null)
  // Prefill Chunking Inputs
  const [maxChunkSize, setMaxChunkSize] = useState(null)
  const [intermediateSize, setIntermediateSize] = useState(null)
  // Toggle between standard calculator and prefill chunking calculator
  const [isPrefillChunking, setIsPrefillChunking] = useState(false)
  // Toggle between model selection and custom model input
  const [modelSelectionTab, setModelSelectionTab] = useState(true)
  // Toggle between device selection and custom device input
  const [deviceSelectionTab, setDeviceSelectionTab] = useState(true)

  // Calculate model memory (params in billions × bytes per param ⇒ GB).
  function calculateMemory(params: number, precision: Precision) {
    const paramSize = { fp32: 4, fp16: 2, int8: 1, int4: 0.5 }
    return params * paramSize[precision] // in GB
  }

  // Calculate memory per input (sequence length and batch size)
  function calculateMemoryPerInput(hiddenSize: number, numLayers: number) {
    const memoryPerInput = 4 * hiddenSize * numLayers
    return memoryPerInput / 1_000_000_000 // in GB
  }

  // Calculate maximum batch size / sequence length: given one dimension
  // (inputSize), return the largest value of the other that fits in the
  // memory left after loading the model.
  function calculateMaxInputSize(
    deviceMemory: number,
    modelParams: number,
    hiddenSize: number,
    numLayers: number,
    precision: Precision,
    inputSize: number,
  ) {
    const memoryPerInput = calculateMemoryPerInput(hiddenSize, numLayers)
    const availableMemory = deviceMemory - calculateMemory(modelParams, precision)
    return Math.floor(availableMemory / (memoryPerInput * inputSize))
  }

  // Check if memory is valid for batch size / sequence length combination
  function calculateMemoryValid(
    deviceMemory: number,
    modelParams: number,
    hiddenSize: number,
    numLayers: number,
    precision: Precision,
    batchSize: number,
    seqLength: number,
  ) {
    const memoryPerInput = calculateMemoryPerInput(hiddenSize, numLayers)
    const availableMemory = deviceMemory - calculateMemory(modelParams, precision)
    return availableMemory >= memoryPerInput * batchSize * seqLength
  }

  // NOTE(review): JSX stripped — only text children, JSX comments, and
  // expression fragments survive below (inputs lost their opening tags, the
  // chart elements are gone entirely). Restore the markup from version control.
  return (
    {/* Toggle Button */}
    {/* Model Memory Calculator */}
    Model Memory Calculator
    Use our Model Memory Calculator to help you estimate the memory footprint of your model for different precisions and the maximum batch size / sequence length combination you can run on your device.
    {/* Model and Device Selection */}
    {/* Model Selection */}
    Model
    {modelSelectionTab ? ( <> ) : ( <> setModelParams(Number(e.target.value))} /> setHiddenSize(Number(e.target.value))} /> setNumLayers(Number(e.target.value))} /> )}
    {/* Device Selection */}
    Device
    {deviceSelectionTab ? ( <> ) : ( <> setDeviceMemory(Number(e.target.value))} /> )}
    Backend Precision Table
    This table shows the precision used by each Takeoff backend for CPUs and GPUs, as well as their accuracy preservation.
    Input parameters
    Sequence Length: The combined length of input tokens and output tokens. To restrict the maximum sequence length for inference on Takeoff, use the API parameters{' '} prompt_new_tokens for input tokens and max_new_tokens for output tokens when making a request.
    Batch Size: The number of sequences that can be processed in parallel. To set a maximum batch size for inference on Takeoff, set the environment variable{' '} TAKEOFF_MAX_BATCH_SIZE to your desired value.
    {/* Prefill Chunking Settings */} {isPrefillChunking && (
    Prefill Chunking Settings
    setMaxChunkSize(Number(e.target.value))} /> setIntermediateSize(Number(e.target.value))} />
    )} {/* Charts Section */} {isPrefillChunking ? ( ) : ( hiddenSize && numLayers && deviceMemory && modelParams && ( <> {/* Model Footprint Chart */}
    Model Footprint
    FP32
    0} />
    {calculateMemory(modelParams, 'fp32')} {deviceMemory ? `/ ${deviceMemory} ` : null}GB
    {/* FP16 */}
    FP16
    0} />
    {calculateMemory(modelParams, 'fp16')} {deviceMemory ? `/ ${deviceMemory} ` : null}GB
    {/* INT8 */}
    INT8
    0} />
    {calculateMemory(modelParams, 'int8')} {deviceMemory ? `/ ${deviceMemory} ` : null}GB
    {/* INT4 */}
    INT4
    0} />
    {calculateMemory(modelParams, 'int4')} {deviceMemory ? `/ ${deviceMemory} ` : null}GB
    {/* Maximum Batch Size / Sequence Length Chart */}
    Maximum Batch Size / Sequence Length
    {/* Batch Size and Sequence Length Inputs */}
    ) )}
  )
}

export default Calculator