Gogryu's picture
adjustments
154d183
raw
history blame
25.6 kB
import React, { useState, useEffect, useRef } from 'react';
import * as d3 from 'd3';
import { useTheme } from '../context/themeContext';
import MODELS from '../utils/models';
import DEVICES from '../utils/devices';
/**
 * Numeric precisions compared by the calculator. The same four keys index the
 * color palette and the bytes-per-parameter tables (4, 2, 1 and 0.5).
 */
type Precision = 'fp32' | 'fp16' | 'int8' | 'int4';
/** Props accepted by ModelSizeBarChart. */
interface ModelSizeBarChartProps {
modelSize: number; // in GB
largestModelSize: number; // upper bound of the bar's x scale; both calculators pass the device memory here
modelPrecision: Precision;
deviceMemorySet: boolean; // when true, the unused remainder of the bar is outlined
activationMemorySize?: number; // optional for standard calculator
}
/** Props accepted by InferenceRuntimeLineChart. */
interface InferenceRuntimeLineChartProps {
availableMemory: AvailableMemory; // in GB
memoryPerInput: number; // in GB
}
/** One plotted point: a sequence length and the batch size computed for it. */
interface LineChartData {
seqLength: number;
batchSize: number;
}
/**
 * One memory figure per precision. Both calculators fill it with device memory
 * minus the model's memory at that precision, so values may be negative.
 */
interface AvailableMemory {
int4: number;
int8: number;
fp16: number;
fp32: number;
}
/**
 * Looks up the hex color associated with a precision.
 *
 * @param precision - Precision key to look up.
 * @returns The palette entry for `precision`, or `'gray'` when the lookup
 *   yields nothing usable.
 */
function chooseColor(precision: Precision) {
  const palette = {
    fp32: '#e45f5b',
    fp16: '#ffc068',
    int8: '#71cce9',
    int4: '#383d95',
  };
  const match = palette[precision];
  if (match) {
    return match;
  }
  return 'gray';
}
/**
 * Sums the three memory terms used by the prefill-chunking calculator, in GB.
 *
 * The terms are:
 * - a per-input term, `4 * hiddenSize * numLayers` divided by 1e9;
 * - the model term, `modelParams` times the bytes-per-parameter of `precision`
 *   (the custom-model form labels `modelParams` as "in billions");
 * - an activation term, `max(2 * intermediateSize, 4 * hiddenSize)` divided by 1e9.
 *
 * @param modelParams - Parameter count of the model.
 * @param hiddenSize - Hidden size of the model.
 * @param numLayers - Number of layers of the model.
 * @param intermediateSize - Intermediate size of the model.
 * @param precision - Precision that selects the bytes-per-parameter factor.
 * @returns The sum of the three terms.
 */
function calculateTotalMemory(
  modelParams: number,
  hiddenSize: number,
  numLayers: number,
  intermediateSize: number,
  precision: Precision
) {
  const BYTES_PER_GB = 1_000_000_000;
  const bytesPerParam = { fp32: 4, fp16: 2, int8: 1, int4: 0.5 }[precision];

  const perInputGb = (4 * hiddenSize * numLayers) / BYTES_PER_GB;
  const weightsGb = modelParams * bytesPerParam; // scaled by the chosen precision
  const activationGb = Math.max(2 * intermediateSize, 4 * hiddenSize) / BYTES_PER_GB;

  return perInputGb + weightsGb + activationGb;
}
/**
 * Horizontal bar showing how much of a memory budget a model occupies.
 * Shared by the standard and prefill-chunking calculators.
 *
 * The bar is drawn with D3 into a 600x50 SVG whose x scale runs from 0 to
 * `largestModelSize`:
 * - if `modelSize + activationMemorySize` exceeds `largestModelSize`, a dashed
 *   frame with the text "Out of Memory" is drawn instead of the bar;
 * - otherwise a segment for `modelSize` is drawn in the precision's color,
 *   followed by a segment for `activationMemorySize` (when > 0) and, when
 *   `deviceMemorySet` is true, an outlined segment for the unused remainder.
 *
 * Nothing is drawn unless both `modelSize` and `largestModelSize` are > 0.
 */
function ModelSizeBarChart({
  modelSize,
  largestModelSize,
  modelPrecision,
  deviceMemorySet,
  activationMemorySize = 0, // default to 0 for standard calculator
}: ModelSizeBarChartProps) {
  const { theme } = useTheme();
  const chartRef = useRef<SVGSVGElement>(null);
  const width = 600;
  const height = 50;

  useEffect(() => {
    const svg = d3.select(chartRef.current);

    // Clear unconditionally so that a previous bar never lingers once the
    // inputs stop being drawable.
    svg.selectAll('*').remove();
    if (!(modelSize > 0 && largestModelSize > 0)) {
      return;
    }

    svg.attr('width', width).attr('height', height);

    const ink = theme === 'dark' ? '#f9fafb' : '#181f26';
    const xScale = d3.scaleLinear().domain([0, largestModelSize]).range([0, width]);
    const usedMemory = modelSize + activationMemorySize;

    if (usedMemory > largestModelSize) {
      // Does not fit: dashed placeholder frame with a centered message.
      svg
        .append('rect')
        .attr('x', 0)
        .attr('y', 0)
        .attr('width', width)
        .attr('height', height)
        .attr('fill', 'transparent')
        .style('stroke', ink)
        .style('stroke-dasharray', '4, 4')
        .style('stroke-width', '2px');
      svg
        .append('text')
        .attr('x', width / 2)
        .attr('y', height / 2)
        .attr('text-anchor', 'middle')
        // `dominant-baseline` is the attribute that applies to <text>;
        // `alignment-baseline` is not honoured on <text> by every browser.
        .attr('dominant-baseline', 'middle')
        .attr('fill', ink)
        .text('Out of Memory');
      return;
    }

    // Model segment.
    svg
      .append('rect')
      .attr('x', 0)
      .attr('y', 0)
      .attr('width', xScale(modelSize))
      .attr('height', height)
      .attr('fill', chooseColor(modelPrecision));

    // Activation segment, appended right after the model segment.
    if (activationMemorySize > 0) {
      svg
        .append('rect')
        .attr('x', xScale(modelSize))
        .attr('y', 0)
        .attr('width', xScale(activationMemorySize))
        .attr('height', height)
        .attr('fill', '#a4b8e0');
    }

    // Outline of the budget that is still free.
    if (deviceMemorySet) {
      svg
        .append('rect')
        .attr('x', xScale(usedMemory))
        .attr('y', 0)
        .attr('width', xScale(largestModelSize - usedMemory))
        .attr('height', height)
        .attr('fill', 'transparent')
        .style('stroke', chooseColor(modelPrecision))
        .style('stroke-width', '2px');
    }
  }, [modelSize, largestModelSize, modelPrecision, deviceMemorySet, activationMemorySize, theme]);

  return <svg ref={chartRef}></svg>;
}
/**
 * Line chart of batch size against sequence length, one curve per precision.
 * Shared by the standard and prefill-chunking calculators.
 *
 * For each precision the curve plots
 * `batchSize = availableMemory / (seqLength * memoryPerInput)` for integer
 * sequence lengths from 1 up to (but not including) 4096, keeping only points
 * with `1 < batchSize <= 128` and `seqLength > 1`. Hovering a curve shows a
 * tooltip with the values under the pointer.
 *
 * The chart is cleared on every redraw and is only drawn when
 * `memoryPerInput > 0` and at least one precision has memory left.
 */
function InferenceRuntimeLineChart({
  availableMemory,
  memoryPerInput,
}: InferenceRuntimeLineChartProps) {
  const { theme } = useTheme();
  const chartRef = useRef<SVGSVGElement>(null);
  const tooltipRef = useRef<HTMLDivElement>(null);
  const maxSeqLength = 4096;
  const maxBatchSize = 128;

  // Depend on the four numbers rather than on the object: callers build
  // `availableMemory` inline, so its identity changes on every parent render
  // and would otherwise force a full redraw each time.
  const { int4, int8, fp16, fp32 } = availableMemory;

  useEffect(() => {
    const svg = d3.select(chartRef.current);

    // Always start from an empty SVG so that curves from previous inputs do
    // not stay on screen when nothing fits any more.
    svg.selectAll('*').remove();

    // Same order in which both calculators list the precisions.
    const memoryByPrecision: [Precision, number][] = [
      ['int4', int4],
      ['int8', int8],
      ['fp16', fp16],
      ['fp32', fp32],
    ];
    const hasHeadroom = memoryByPrecision.some(([, memory]) => memory > 0);
    if (!(memoryPerInput > 0 && hasHeadroom)) {
      return;
    }

    const margin = { top: 20, right: 20, bottom: 50, left: 50 };
    const width = 600 - margin.left - margin.right;
    const height = 400 - margin.top - margin.bottom;
    const ink = theme === 'dark' ? '#f9fafb' : '#181f26';
    const legendPrecisions: Precision[] = ['fp32', 'fp16', 'int8', 'int4'];

    const xScale = d3.scaleLinear().domain([0, maxSeqLength]).range([0, width]);
    const yScale = d3.scaleLinear().domain([0, maxBatchSize]).range([height, 0]);
    const xAxis = d3.axisBottom(xScale);
    const yAxis = d3.axisLeft(yScale);

    // NOTE(review): the zoom behavior is attached below to a freshly appended
    // <g> that never receives children, and the handler transforms every
    // <path> in the SVG. Whether zoom gestures ever reach that group, and
    // whether the axis paths are meant to be transformed too, should be
    // confirmed; the wiring is kept as it was.
    const zoom = d3
      .zoom()
      .scaleExtent([0.5, 10])
      .translateExtent([
        [-width, -height],
        [2 * width, 2 * height],
      ])
      .on('zoom', (event) => {
        const transform = event.transform;
        svg.select<SVGGElement>('.x-axis').call(xAxis.scale(transform.rescaleX(xScale)));
        svg.select<SVGGElement>('.y-axis').call(yAxis.scale(transform.rescaleY(yScale)));
        svg.selectAll('path').attr('transform', transform);
      });

    svg
      .attr('width', width + margin.left + margin.right)
      .attr('height', height + margin.top + margin.bottom)
      .append('g')
      .attr('transform', `translate(${margin.left}, ${margin.top})`)
      .call(zoom);

    // Axes.
    svg
      .append('g')
      .attr('class', 'x-axis')
      .attr('transform', `translate(${margin.left}, ${height + margin.top})`)
      .call(xAxis);
    svg
      .append('g')
      .attr('class', 'y-axis')
      .attr('transform', `translate(${margin.left}, ${margin.top})`)
      .call(yAxis);

    // Axis titles.
    svg
      .append('text')
      .attr('transform', `translate(${width / 2 + margin.left}, ${height + margin.top + 40})`)
      .style('text-anchor', 'middle')
      .attr('fill', ink)
      .text('Sequence Length');
    svg
      .append('text')
      .attr('transform', `rotate(-90)`)
      .attr('y', 0)
      .attr('x', 0 - height / 2 - margin.top)
      .attr('dy', '1em')
      .style('text-anchor', 'middle')
      .attr('fill', ink)
      .text('Batch Size');

    // Legend: swatch colors come from chooseColor so they always match the curves.
    const legend = svg
      .append('g')
      .attr('class', 'legend')
      .attr('transform', `translate(${width - 20}, 20)`);
    legendPrecisions.forEach((precision, index) => {
      const legendItem = legend.append('g').attr('transform', `translate(0, ${index * 30})`);
      legendItem
        .append('rect')
        .attr('x', 10)
        .attr('y', 10)
        .attr('width', 10)
        .attr('height', 10)
        .style('fill', chooseColor(precision));
      legendItem
        .append('text')
        .attr('x', 30)
        .attr('y', 16)
        .text(precision.toUpperCase())
        .style('font-size', '16px')
        .style('fill', ink)
        // `dominant-baseline` is the attribute that applies to <text>;
        // `alignment-baseline` is not honoured on <text> by every browser.
        .attr('dominant-baseline', 'middle');
    });
    legend
      .append('rect')
      .attr('class', 'legend-box')
      .attr('width', 80)
      .attr('height', precisions_boxHeight(legendPrecisions.length))
      .style('fill', 'none')
      .style('stroke-width', '1px')
      .style('stroke', ink);

    const tooltip = d3.select(tooltipRef.current);

    // The generator does not depend on the precision, so build it once.
    const line = d3
      .line<LineChartData>()
      .x((d) => xScale(d.seqLength))
      .y((d) => yScale(d.batchSize))
      .curve(d3.curveBasis);

    for (const [precision, memory] of memoryByPrecision) {
      const points = d3
        .range(1, maxSeqLength, 1)
        .map((seqLength) => ({
          seqLength,
          batchSize: memory / (seqLength * memoryPerInput),
        }))
        .filter((d) => d.batchSize <= maxBatchSize && d.batchSize > 1 && d.seqLength > 1);

      svg
        .append('g')
        .attr('transform', `translate(${margin.left}, ${margin.top})`)
        .append('path')
        .datum(points)
        .attr('fill', 'none')
        .attr('stroke', chooseColor(precision))
        .attr('stroke-width', 4)
        .attr('d', line)
        .on('mouseover', () => {
          tooltip.style('opacity', 1);
          tooltip.style('background-color', theme === 'dark' ? '#181f26' : '#f9fafb');
        })
        .on('mousemove', (event) => {
          const [x, y] = d3.pointer(event);
          const xValue = xScale.invert(x);
          const yValue = yScale.invert(y);
          tooltip
            .html(`Sequence Length: ${xValue.toFixed(0)}<br/>Batch Size: ${yValue.toFixed(0)}`)
            // The tooltip is an HTML <div>: its text color is the CSS `color`
            // property, not the SVG `fill` attribute.
            .style('color', ink)
            .style('left', event.pageX + 10 + 'px')
            .style('top', event.pageY + 10 + 'px');
        })
        .on('mouseout', () => {
          tooltip.style('opacity', 0);
        });
    }

    // Height of the legend frame: 30px per entry.
    function precisions_boxHeight(entryCount: number) {
      return entryCount * 30;
    }
  }, [int4, int8, fp16, fp32, memoryPerInput, theme]);

  return (
    <>
      <div id='tooltip' ref={tooltipRef}></div>
      <svg ref={chartRef} width={600} height={400} />
    </>
  );
}
/**
 * Results section of the calculator in prefill-chunking mode.
 *
 * Renders one memory-footprint bar per precision followed by the batch-size /
 * sequence-length chart. Renders nothing while any of the inputs is missing
 * or zero.
 *
 * The per-precision total comes from `calculateTotalMemory`, which already
 * includes the activation term. The bar chart draws activation memory as its
 * own segment, so the first segment is given the total minus the activation
 * term; otherwise activation memory would be counted twice in the bar and in
 * its out-of-memory check.
 */
function PrefillChunkingCalculator({
  deviceMemory,
  modelParams,
  hiddenSize,
  numLayers,
  intermediateSize,
}: {
  deviceMemory: number;
  modelParams: number;
  hiddenSize: number;
  numLayers: number;
  intermediateSize: number;
}) {
  if (!deviceMemory || !modelParams || !hiddenSize || !numLayers || !intermediateSize) {
    return null;
  }

  const precisions: Precision[] = ['fp32', 'fp16', 'int8', 'int4'];

  // Same activation term that calculateTotalMemory adds to its result (GB).
  const activationMemory = Math.max(2 * intermediateSize, 4 * hiddenSize) / 1_000_000_000;

  // Compute each total once; it feeds the bar, its label and the line chart.
  const totalFor = (precision: Precision) =>
    calculateTotalMemory(modelParams, hiddenSize, numLayers, intermediateSize, precision);
  const totalMemory: Record<Precision, number> = {
    fp32: totalFor('fp32'),
    fp16: totalFor('fp16'),
    int8: totalFor('int8'),
    int4: totalFor('int4'),
  };

  return (
    <>
      {/* Model Footprint with Prefill Chunking */}
      <div className='chart'>
        <div className='text-2xl text-center mb-4'>Model Footprint with Prefill Chunking</div>
        <div className='space-y-8'>
          {precisions.map((precision) => (
            <div key={precision} className='chart-row'>
              <div className='chart-row-title'>{precision.toUpperCase()}</div>
              <ModelSizeBarChart
                modelSize={totalMemory[precision] - activationMemory}
                largestModelSize={deviceMemory}
                modelPrecision={precision}
                deviceMemorySet={deviceMemory > 0}
                activationMemorySize={activationMemory}
              />
              <div className='chart-row-size ml-8'>
                {totalMemory[precision].toFixed(2)} / {deviceMemory} GB
              </div>
            </div>
          ))}
        </div>
      </div>
      {/* Inference Runtime with Prefill Chunking */}
      <div className='chart'>
        <div className='text-2xl text-center mb-4'>
          Maximum Batch Size / Sequence Length with Prefill Chunking
        </div>
        <InferenceRuntimeLineChart
          availableMemory={{
            int4: deviceMemory - totalMemory.int4,
            int8: deviceMemory - totalMemory.int8,
            fp16: deviceMemory - totalMemory.fp16,
            fp32: deviceMemory - totalMemory.fp32,
          }}
          memoryPerInput={(4 * hiddenSize * numLayers) / 1_000_000_000}
        />
      </div>
    </>
  );
}
/**
 * Results section of the calculator in standard mode.
 *
 * Renders one memory-footprint bar per precision followed by the batch-size /
 * sequence-length chart. Renders nothing while any of the inputs is missing
 * or zero.
 *
 * Model memory is `modelParams` times the bytes-per-parameter of the
 * precision; the per-input memory is `4 * hiddenSize * numLayers` divided by
 * 1e9.
 */
function StandardCalculator({
  deviceMemory,
  modelParams,
  hiddenSize,
  numLayers,
}: {
  deviceMemory: number;
  modelParams: number;
  hiddenSize: number;
  numLayers: number;
}) {
  if (!deviceMemory || !modelParams || !hiddenSize || !numLayers) {
    return null;
  }

  const precisions: Precision[] = ['fp32', 'fp16', 'int8', 'int4'];
  const bytesPerParam: Record<Precision, number> = { fp32: 4, fp16: 2, int8: 1, int4: 0.5 };

  // Compute each footprint once; it feeds the bar, its label and the line chart.
  const modelMemory: Record<Precision, number> = {
    fp32: modelParams * bytesPerParam.fp32,
    fp16: modelParams * bytesPerParam.fp16,
    int8: modelParams * bytesPerParam.int8,
    int4: modelParams * bytesPerParam.int4,
  }; // in GB

  const memoryPerInput = (4 * hiddenSize * numLayers) / 1_000_000_000; // in GB

  return (
    <>
      {/* Model Footprint */}
      <div className='chart mb-8'>
        <div className='text-2xl text-center mb-4'>Model Footprint</div>
        <div className='space-y-8'>
          {precisions.map((precision) => (
            <div key={precision} className='chart-row'>
              <div className='chart-row-title'>{precision.toUpperCase()}</div>
              <ModelSizeBarChart
                modelSize={modelMemory[precision]}
                largestModelSize={deviceMemory}
                modelPrecision={precision}
                deviceMemorySet={deviceMemory > 0}
              />
              <div className='chart-row-size ml-8'>
                {modelMemory[precision].toFixed(2)} / {deviceMemory} GB
              </div>
            </div>
          ))}
        </div>
      </div>
      {/* Maximum Batch Size / Sequence Length */}
      <div className='chart'>
        <div className='text-2xl text-center mb-4'>
          Maximum Batch Size / Sequence Length
        </div>
        <InferenceRuntimeLineChart
          availableMemory={{
            int4: deviceMemory - modelMemory.int4,
            int8: deviceMemory - modelMemory.int8,
            fp16: deviceMemory - modelMemory.fp16,
            fp32: deviceMemory - modelMemory.fp32,
          }}
          memoryPerInput={memoryPerInput}
        />
      </div>
    </>
  );
}
/**
 * Converts the raw text of a form control into calculator state.
 *
 * @param raw - The control's current `value`.
 * @returns `null` for an empty control, otherwise `Number(raw)`.
 */
function parseNumberInput(raw: string): number | null {
  return raw === '' ? null : Number(raw);
}

/**
 * Main calculator page.
 *
 * Holds the model inputs (parameter count, hidden size, number of layers,
 * intermediate size) and the device memory, lets the user fill them either
 * from the `MODELS` / `DEVICES` presets or by hand, and renders the standard
 * or the prefill-chunking results depending on the selected mode.
 *
 * Unset inputs are stored as `null` and handed to the results components as
 * `0`, which those components treat as "not provided" and render nothing for.
 */
const Calculator = () => {
  const [modelParams, setModelParams] = useState<number | null>(null);
  const [hiddenSize, setHiddenSize] = useState<number | null>(null);
  const [numLayers, setNumLayers] = useState<number | null>(null);
  const [intermediateSize, setIntermediateSize] = useState<number | null>(null);
  const [deviceMemory, setDeviceMemory] = useState<number | null>(null);
  const [isPrefillChunking, setIsPrefillChunking] = useState<boolean>(false);
  const [modelSelectionTab, setModelSelectionTab] = useState<boolean>(true);
  const [deviceSelectionTab, setDeviceSelectionTab] = useState<boolean>(true);

  return (
    <div className='flex flex-col items-center justify-center min-h-screen px-4'>
      {/* Toggle Between Standard and Prefill Chunking */}
      <div className='mb-4 flex space-x-4'>
        <button
          className={!isPrefillChunking ? 'calculator-input-tab-active' : 'calculator-input-tab'}
          onClick={() => setIsPrefillChunking(false)}
        >
          Standard Calculator
        </button>
        <button
          className={isPrefillChunking ? 'calculator-input-tab-active' : 'calculator-input-tab'}
          onClick={() => setIsPrefillChunking(true)}
        >
          Calculator with Prefill Chunking
        </button>
      </div>
      {/* Model and Device Selection */}
      <div className='w-full max-w-4xl'>
        <div className='text-4xl mb-4 text-center'>Model Memory Calculator</div>
        <div className='mb-6 text-center'>
          Use our Model Memory Calculator to help you estimate the memory footprint of your model
          and the maximum batch size/sequence length combination you can run on your device.
        </div>
        <div className='grid grid-cols-1 sm:grid-cols-2 gap-4 mb-6'>
          {/* Model Selection */}
          <div className='calculator-input-box'>
            <div className='text-2xl calculator-input-title'>Model</div>
            <div className='calculator-input-content'>
              <div className='mb-2'>
                <button
                  className={modelSelectionTab ? 'calculator-input-tab-active' : 'calculator-input-tab'}
                  onClick={() => setModelSelectionTab(true)}
                >
                  Model Selection
                </button>
                <button
                  className={modelSelectionTab ? 'calculator-input-tab' : 'calculator-input-tab-active'}
                  onClick={() => setModelSelectionTab(false)}
                >
                  Custom Model
                </button>
              </div>
              <div>
                {modelSelectionTab ? (
                  <>
                    <label htmlFor='model'>Select a Model</label>
                    <select
                      id='model'
                      className='calculator-select'
                      onChange={(e) => {
                        // Options are keyed by name: several models can share
                        // a parameter count, so the count cannot identify one.
                        const selectedModel = MODELS.find(
                          (model) => model.name === e.target.value,
                        );
                        // "None selected" (or an unknown value) clears the
                        // inputs instead of leaving the previous model behind.
                        setModelParams(selectedModel ? selectedModel.params : null);
                        setHiddenSize(selectedModel ? selectedModel.hidden_size : null);
                        setNumLayers(selectedModel ? selectedModel.num_hidden_layers : null);
                        setIntermediateSize(selectedModel ? selectedModel.intermediate_size : null);
                      }}
                    >
                      <option value=''>None selected</option>
                      {MODELS.map((model) => (
                        <option key={model.name} value={model.name}>
                          {model.name}
                        </option>
                      ))}
                    </select>
                  </>
                ) : (
                  <>
                    <label htmlFor='modelParams'>Model Parameters (in billions)</label>
                    <input
                      type='number'
                      id='modelParams'
                      className='calculator-input mb-2'
                      placeholder='e.g. 7 (for LLaMA-7B)'
                      // `??` keeps a typed 0 on screen, so values such as 0.5
                      // can be entered; an empty field is stored as null.
                      value={modelParams ?? ''}
                      min={0}
                      onChange={(e) => setModelParams(parseNumberInput(e.target.value))}
                    />
                    <label htmlFor='hiddenSize'>Hidden Size</label>
                    <input
                      type='number'
                      id='hiddenSize'
                      className='calculator-input mb-2'
                      placeholder='e.g. 4096 (for LLaMA-7B)'
                      value={hiddenSize ?? ''}
                      min={1}
                      onChange={(e) => setHiddenSize(parseNumberInput(e.target.value))}
                    />
                    <label htmlFor='numLayers'>Number of Layers</label>
                    <input
                      type='number'
                      id='numLayers'
                      className='calculator-input'
                      placeholder='e.g. 32 (for LLaMA-7B)'
                      value={numLayers ?? ''}
                      min={1}
                      onChange={(e) => setNumLayers(parseNumberInput(e.target.value))}
                    />
                    {isPrefillChunking && (
                      <>
                        <label htmlFor='intermediateSize'>Intermediate Size</label>
                        <input
                          type='number'
                          id='intermediateSize'
                          className='calculator-input'
                          placeholder='e.g. 11008 (for LLaMA-7B)'
                          value={intermediateSize ?? ''}
                          min={1}
                          onChange={(e) => setIntermediateSize(parseNumberInput(e.target.value))}
                        />
                      </>
                    )}
                  </>
                )}
              </div>
            </div>
          </div>
          {/* Device Selection */}
          <div className='calculator-input-box'>
            <div className='text-2xl calculator-input-title'>Device</div>
            <div className='calculator-input-content'>
              <div className='mb-2'>
                <button
                  className={deviceSelectionTab ? 'calculator-input-tab-active' : 'calculator-input-tab'}
                  onClick={() => {
                    setDeviceSelectionTab(true);
                    setDeviceMemory(null);
                  }}
                >
                  Device Selection
                </button>
                <button
                  className={deviceSelectionTab ? 'calculator-input-tab' : 'calculator-input-tab-active'}
                  onClick={() => {
                    setDeviceSelectionTab(false);
                    setDeviceMemory(null);
                  }}
                >
                  Custom Device
                </button>
              </div>
              <div>
                {deviceSelectionTab ? (
                  <>
                    <label htmlFor='device'>Select a Device</label>
                    <select
                      id='device'
                      className='calculator-select'
                      onChange={(e) => setDeviceMemory(parseNumberInput(e.target.value))}
                    >
                      <option value=''>None selected</option>
                      {DEVICES.map((device) => (
                        <option key={device.name} value={device.size}>
                          {device.name}
                        </option>
                      ))}
                    </select>
                  </>
                ) : (
                  <>
                    <label htmlFor='deviceMemory'>Device RAM (in GB)</label>
                    <input
                      type='number'
                      id='deviceMemory'
                      className='calculator-input'
                      placeholder='e.g. 24'
                      value={deviceMemory ?? ''}
                      min={0}
                      onChange={(e) => setDeviceMemory(parseNumberInput(e.target.value))}
                    />
                  </>
                )}
              </div>
            </div>
          </div>
        </div>
        {/* Render Appropriate Calculator Based on Toggle */}
        {isPrefillChunking ? (
          <PrefillChunkingCalculator
            deviceMemory={deviceMemory ?? 0}
            modelParams={modelParams ?? 0}
            hiddenSize={hiddenSize ?? 0}
            numLayers={numLayers ?? 0}
            intermediateSize={intermediateSize ?? 0}
          />
        ) : (
          <StandardCalculator
            deviceMemory={deviceMemory ?? 0}
            modelParams={modelParams ?? 0}
            hiddenSize={hiddenSize ?? 0}
            numLayers={numLayers ?? 0}
          />
        )}
      </div>
    </div>
  );
};
export default Calculator;