thomwolf's picture
thomwolf HF staff
update
f2c15d5
raw
history blame
12.5 kB
import * as d3 from 'd3';
/**
 * Estimate transformer activation memory (in bytes) and return it as a nested
 * `{ name, children | value }` tree that can be fed to `d3.hierarchy`.
 *
 * Formulas follow https://arxiv.org/pdf/2205.05198 (equation numbers in the
 * comments below refer to that paper).
 *
 * The leaves of the returned tree always add up to
 * `L * oneLayer + inputDropout + outputLayerNorm + outputLayerProjection + outputCrossEntropy`,
 * where `oneLayer` depends on `recomputation`:
 *   - "none":      attention + feedforward + 2 layer norms (each layer is
 *                  broken down into those three children);
 *   - "selective": `34 * s * b * h` (each layer is a single leaf);
 *   - "full":      `2 * s * b * h`  (each layer is a single leaf).
 *
 * @param {number} a - attention heads
 * @param {number} b - micro batch size
 * @param {number} h - hidden dimension size
 * @param {number} h_ff - feedforward dimension size (often h_ff = 4h)
 * @param {number} L - number of layers
 * @param {number} s - sequence length
 * @param {number} v - vocab size
 * @param {boolean} [mixed=true] - 2 bytes per activation value when true, 4 otherwise
 * @param {"none"|"selective"|"full"} [recomputation="none"]
 * @param {"relu"|"gelu"|"swiglu"} [ff_activation="relu"]
 * @returns {{name: string, children: Array<Object>}} hierarchy rooted at "activationMemory"
 * @throws {Error} if `recomputation` or `ff_activation` is not one of the supported values
 */
export function activationMemory(
  a,
  b,
  h,
  h_ff,
  L,
  s,
  v,
  mixed = true,
  recomputation = "none",
  ff_activation = "relu"
) {
  console.log('activationMemory called with:', { a, b, h, h_ff, L, s, v, mixed, recomputation, ff_activation });

  const bytesPerValue = mixed ? 2 : 4;
  const sbh = s * b * h;

  const oneLayerAttention = sbh * (bytesPerValue * 5 + 1) + ((2 * bytesPerValue + 1) * a * s * s * b); // eq (2)

  // Inputs of the 1st/2nd linear layers are stored for every activation type.
  const linearInputs = sbh * bytesPerValue + (s * b * h_ff * bytesPerValue);
  const feedforwardDropout = sbh; // dropout mask is lower precision (one byte per value)
  let oneLayerFeedforward;
  switch (ff_activation) {
    case "relu":
      oneLayerFeedforward = linearInputs + feedforwardDropout;
      break;
    case "gelu":
      oneLayerFeedforward = linearInputs
        + s * b * h_ff * bytesPerValue // inputs of activation function (not really necessary for Relu)
        + feedforwardDropout;
      break;
    case "swiglu":
      oneLayerFeedforward = linearInputs
        + s * b * h_ff * bytesPerValue * 3 // inputs of activation function
        + feedforwardDropout;
      break;
    default:
      // Mirror the recomputation check below instead of silently emitting an
      // undefined Feedforward value.
      throw new Error("Invalid ff_activation value");
  }

  const layerNorm = sbh * bytesPerValue;
  const inputDropout = sbh; // section 4.3
  const outputLayerNorm = sbh * bytesPerValue;
  const outputLayerProjection = sbh * bytesPerValue;
  const outputCrossEntropy = s * b * v * 4; // In FP32

  // Build one layer node. Without recomputation the layer is broken down into
  // its components; with recomputation the layer is a single leaf whose value
  // is the per-layer total, so the recomputation setting is actually reflected
  // in the returned tree.
  let makeLayer;
  if (recomputation === "none") {
    makeLayer = (name) => ({
      name,
      children: [
        { name: 'Attention', value: oneLayerAttention },
        { name: 'Feedforward', value: oneLayerFeedforward },
        { name: 'LayerNorm', value: 2 * layerNorm },
      ],
    });
  } else if (recomputation === "selective") {
    // NOTE(review): the constant 34 is taken as-is from the original code; it
    // does not vary with `mixed`, `h_ff` or `ff_activation` — confirm intended.
    const oneLayer = sbh * 34; // eq (6)
    makeLayer = (name) => ({ name, value: oneLayer });
  } else if (recomputation === "full") {
    const oneLayer = sbh * 2;
    makeLayer = (name) => ({ name, value: oneLayer });
  } else {
    throw new Error("Invalid recomputation value");
  }

  return {
    name: "activationMemory",
    children: [
      ...Array.from({ length: L }, (_, index) => makeLayer(`Layer ${index + 1}`)),
      { name: 'Dropout', value: inputDropout },
      { name: 'LayerNorm', value: outputLayerNorm },
      { name: 'Projection', value: outputLayerProjection },
      { name: 'Cross Entropy', value: outputCrossEntropy },
    ],
  };
}
/**
 * Compute three byte counts derived from the model's parameter count.
 *
 * The parameter count is `h * (v + s)` (embeddings) plus `12h^2 + 13h` per
 * layer plus `2h`.
 *
 * @param {number} h - hidden dimension size
 * @param {number} L - number of layers
 * @param {number} s - sequence length
 * @param {number} v - vocab size
 * @param {number} [k=8] - per-parameter multiplier for the third entry;
 *   increased by 4 when `mixed` is true (presumably optimizer-state bytes plus
 *   a full-precision copy of the weights — confirm)
 * @param {boolean} [mixed=true] - 2 bytes per parameter when true, 4 otherwise
 * @returns {number[]} `[bytesPerParameter * n, bytesPerParameter * n, k * n]`
 */
export function paramGradsOpt(h, L, s, v, k = 8, mixed = true) {
  console.log('paramGradsOpt called with:', { h, L, s, v, k, mixed });

  const embeddingCount = h * (v + s);
  const perLayerCount = 12 * h ** 2 + 13 * h;
  const trailingCount = 2 * h;
  const paramCount = embeddingCount + L * perLayerCount + trailingCount;

  // Derive the effective multiplier without reassigning the `k` parameter.
  const effectiveK = mixed ? k + 4 : k;
  const weightBytes = (mixed ? 2 : 4) * paramCount;

  const result = [weightBytes, weightBytes, effectiveK * paramCount];
  console.log('paramGradsOpt result:', result);
  return result;
}
/**
 * Read the current control values from the DOM, recompute the memory
 * estimates and redraw the treemap (plus legend) inside the `<svg>` found
 * under `#graph`.
 *
 * Expects elements with ids a, b, h, h_ff, L, s, v, mixed, recomputation and
 * ff_activation to exist, and an `<svg>` to already be present inside
 * `#graph`.
 */
export function updateGraph() {
  console.log('updateGraph called');

  const a = +document.getElementById('a').value;
  const b = +document.getElementById('b').value;
  const h = +document.getElementById('h').value;
  const h_ff = +document.getElementById('h_ff').value;
  const L = +document.getElementById('L').value;
  const s = +document.getElementById('s').value;
  const v = +document.getElementById('v').value;
  const mixed = document.getElementById('mixed').checked;
  const recomputation = document.getElementById('recomputation').value;
  const ff_activation = document.getElementById('ff_activation').value;
  console.log('Slider values:', { a, b, h, h_ff, L, s, v, mixed, recomputation, ff_activation });

  const fixedSize100GB = 100 * 1024 * 1024 * 1024; // 100GB in bytes
  const activationMemoryData = activationMemory(a, b, h, h_ff, L, s, v, mixed, recomputation, ff_activation);
  // Forward the `mixed` checkbox so the parameter bytes follow the selected
  // precision (passing `undefined` keeps the default for `k`).
  // NOTE(review): only the first entry of the result is plotted even though
  // the node is labelled 'paramGradsOpt' — confirm that is intended.
  const paramGradsOptValue = paramGradsOpt(h, L, s, v, undefined, mixed)[0];

  const data = {
    name: "root",
    children: [
      {
        name: 'Total',
        value: 0,
        children: [
          activationMemoryData,
          { name: 'paramGradsOpt', value: paramGradsOptValue }
        ]
      }
    ]
  };
  console.log('Data for treemap:', data);

  const width = 700;
  const height = 450;
  const legendHeight = 50;

  // Start from an empty drawing on every update.
  const svg = d3.select("#graph").select("svg");
  svg.selectAll("*").remove();
  svg.attr("width", width)
    .attr("height", height + legendHeight);

  const treemap = d3.treemap()
    .size([width, height])
    .paddingOuter(3)
    .paddingTop(19)
    .paddingInner(1)
    .round(true);

  const root = d3.hierarchy(data)
    .sum(d => d.value);

  // Raise the 'Total' node to at least 100GB after summing; presumably this
  // keeps the drawn areas relative to a fixed 100GB budget — confirm.
  if (root.children[0].value < fixedSize100GB) {
    root.children[0].value = fixedSize100GB;
  }
  console.log('Treemap root:', root);
  treemap(root);

  // Fill/stroke colour keyed by node name.
  const color = d => {
    switch (d.data.name) {
      case 'paramGradsOpt': return '#4e79a7'; // Blue
      case 'activationMemory': return '#f28e2c'; // Orange
      case 'fixed100GB': return '#59a14f'; // Green
      case 'Attention': return '#e15759'; // Red
      case 'Feedforward': return '#f28e2c'; // Orange
      case 'LayerNorm': return '#9b59b6'; // Purple
      case 'Dropout': return '#e15759'; // Red
      case 'Projection': return '#f28e2c'; // Orange
      case 'Cross Entropy': return '#e15759'; // Red
      default: return '#59a14f'; // Green (any other name, e.g. 'Total' or 'Layer N')
    }
  };

  const cell = svg.selectAll("g")
    .data(root.descendants())
    .join("g")
    .attr("transform", d => `translate(${d.x0},${d.y0})`);

  // The depth-1 node is drawn as an outline only; every other node is filled.
  cell.append("rect")
    .attr("width", d => d.x1 - d.x0)
    .attr("height", d => d.y1 - d.y0)
    .attr("fill", d => d.depth === 1 ? "none" : color(d))
    .attr("stroke", d => d.depth === 1 ? color(d) : "none")
    .attr("stroke-width", 2);

  const fontSize = 10;
  const padding = 2;
  cell.append("text")
    .attr("font-size", `${fontSize}px`)
    .attr("font-family", "sans-serif")
    .each(function(d) {
      if (d.depth === 0) return; // Skip root node
      const node = d3.select(this);
      const name = d.data.name;
      const value = formatBytes(d.value);
      if (d.depth === 1) {
        // Depth-1 node ('Total'): full bold label
        node.attr("transform", `translate(${padding},${fontSize + padding})`)
          .attr("font-weight", "bold")
          .text(`${name}: ${value}`);
      } else {
        // Deeper nodes
        node.attr("transform", `translate(${padding},${fontSize + padding})`)
          .text(name[0].toUpperCase()) // Display only the first letter
          .append("title") // Add title for hover effect
          .text(`${name}: ${value}`);
      }
    });

  // Add invisible rect for better hover area
  cell.append("rect")
    .attr("width", d => d.x1 - d.x0)
    .attr("height", d => d.y1 - d.y0)
    .attr("fill", "none")
    .attr("pointer-events", "all")
    .append("title")
    .text(d => `${d.data.name}: ${formatBytes(d.value)}`);

  // Legend: the children of 'Total' followed by 'Total' itself, laid out
  // horizontally below the treemap.
  const legendData = root.children[0].children.concat(root.children[0]);
  const legend = svg.append("g")
    .attr("font-family", "sans-serif")
    .attr("font-size", 10)
    .attr("text-anchor", "start")
    .attr("transform", `translate(0, ${height})`)
    .selectAll("g")
    .data(legendData)
    .join("g")
    .attr("transform", (d, i) => `translate(${i * 240}, 0)`);

  // NOTE(review): no node in `data` is named 'fixed100GB', so every legend
  // swatch is currently filled — confirm whether 'Total' should be outlined.
  legend.append("rect")
    .attr("x", 0)
    .attr("width", 19)
    .attr("height", 19)
    .attr("fill", d => d.data.name === 'fixed100GB' ? 'none' : color(d))
    .attr("stroke", d => d.data.name === 'fixed100GB' ? color(d) : 'none')
    .attr("stroke-width", 2);

  legend.append("text")
    .attr("x", 24)
    .attr("y", 9.5)
    .attr("dy", "0.32em")
    .text(d => `${d.data.name}: ${formatBytes(d.value)}`);

  console.log('Treemap nodes created');
}
/**
 * Format a byte count as a human-readable string with two decimals, using
 * 1024-based units from 'Bytes' up to 'PB'.
 *
 * The unit index is clamped to the available units: values below 1 stay in
 * 'Bytes' and values of 1024 PB or more stay in 'PB' instead of indexing
 * outside the unit table.
 *
 * @param {number} bytes - size in bytes
 * @returns {string} e.g. '1.50 KB'; exactly '0 Bytes' for 0
 */
function formatBytes(bytes) {
  const sizes = ['Bytes', 'KB', 'MB', 'GB', 'TB', 'PB'];
  if (bytes === 0) return '0 Bytes';

  // Repeated division by 1024 is exact for doubles, so this avoids the
  // rounding issues of a log-based index and makes clamping straightforward.
  let scaled = Number(bytes);
  let unitIndex = 0;
  while (Math.abs(scaled) >= 1024 && unitIndex < sizes.length - 1) {
    scaled /= 1024;
    unitIndex += 1;
  }
  return `${scaled.toFixed(2)} ${sizes[unitIndex]}`;
}
// Named control presets, keyed by the label used in the preset selector.
// Each row lists the values that differ between presets; the vocab size,
// mixed flag and recomputation mode are the same for all of them.
const presets = Object.fromEntries(
  [
    // [label, a, b, h, h_ff, L, s, ff_activation]
    ["Tiny", 16, 3, 1024, 4096, 1, 7, "gelu"],
    ["8B", 32, 32, 4096, 16384, 32, 256, "swiglu"],
    ["70B", 64, 32, 8192, 32768, 80, 256, "swiglu"],
    ["405B", 128, 32, 16384, 65536, 126, 256, "swiglu"],
  ].map(([label, a, b, h, h_ff, L, s, ff_activation]) => [
    label,
    { a, b, h, h_ff, L, s, v: 30522, mixed: true, recomputation: "none", ff_activation },
  ]),
);
/**
 * Apply a named preset to the controls and redraw the graph.
 *
 * For every key of the preset, the element with that id and the element with
 * id `<key>_input` are updated when they exist (checkboxes via `checked`,
 * everything else via `value`).
 *
 * "custom" and any name that is not a key of `presets` leave the controls
 * untouched and do not trigger a redraw.
 *
 * @param {string} preset - key of `presets`, or "custom"
 */
function setPresetValues(preset) {
  // An unknown name used to crash on Object.keys(undefined); treat it like
  // "custom" instead.
  if (preset === "custom" || !Object.hasOwn(presets, preset)) return;

  for (const [key, value] of Object.entries(presets[preset])) {
    const element = document.getElementById(key);
    const inputElement = document.getElementById(`${key}_input`);

    if (element) {
      if (element.type === 'checkbox') {
        element.checked = value;
      } else {
        element.value = value;
      }
    }
    if (inputElement) {
      inputElement.value = value;
    }
  }

  updateGraph();
}
/**
 * Keep a range slider and its companion text/number input in sync, and redraw
 * the graph whenever either changes.
 *
 * - Slider moves copy the slider value into the input.
 * - Input edits are parsed as base-10 integers, fall back to the slider's
 *   `min` when not a number, are clamped to the slider's `[min, max]`, and
 *   the result is written back to both controls.
 *
 * @param {string} sliderId - id of the slider element
 * @param {string} inputId - id of the companion input element
 */
function syncSliderAndInput(sliderId, inputId) {
  const slider = document.getElementById(sliderId);
  const input = document.getElementById(inputId);

  slider.addEventListener('input', () => {
    input.value = slider.value;
    updateGraph();
  });

  input.addEventListener('input', () => {
    // Read the bounds at event time: `max` is assigned after this listener is
    // registered. An explicit radix keeps input such as "0x20" from being
    // read as hexadecimal.
    const lowerBound = Number.parseInt(slider.min, 10);
    const upperBound = Number.parseInt(slider.max, 10);
    const typed = Number.parseInt(input.value, 10);

    const candidate = Number.isNaN(typed) ? lowerBound : typed;
    const value = Math.max(lowerBound, Math.min(upperBound, candidate));

    slider.value = value;
    input.value = value;
    updateGraph();
  });
}
/**
 * One-time setup of the memory plot: wires every control to `updateGraph`,
 * sets the slider upper bounds, appends the `<svg>` to `#graph` and draws the
 * initial graph.
 */
export const init_memory_plot = function () {
  console.log('DOM fully loaded and parsed');

  // Each numeric control is a slider paired with an `<id>_input` field.
  for (const id of ['a', 'b', 'h', 'h_ff', 'L', 's', 'v']) {
    syncSliderAndInput(id, `${id}_input`);
  }

  // Non-numeric controls simply trigger a redraw on change.
  for (const id of ['recomputation', 'ff_activation', 'mixed']) {
    document.getElementById(id).addEventListener('change', updateGraph);
  }

  document.getElementById('presets').addEventListener('change', (event) => {
    setPresetValues(event.target.value);
  });

  // Upper bound for each slider.
  const sliderMaxima = [
    ['a', 128],
    ['b', 53248],
    ['h', 16384],
    ['h_ff', 65536],
    ['L', 126],
    ['s', 128000],
    ['v', 100000],
  ];
  for (const [id, max] of sliderMaxima) {
    document.getElementById(id).max = max;
  }

  console.log('Adding svg');
  d3.select("#graph")
    .append("svg")
    .attr("width", 960)
    .attr("height", 500);

  updateGraph();
};