|
import * as d3 from 'd3'; |
|
|
|
export function activationMemory( |
|
a, |
|
b, |
|
h, |
|
h_ff, |
|
L, |
|
s, |
|
v, |
|
tp = 1, |
|
mixed = true, |
|
recomputation = "none", |
|
ff_activation = "relu", |
|
seq_parallel = false |
|
) { |
|
console.log('activationMemory called with:', { a, b, h, h_ff, L, s, v, tp, mixed, recomputation, ff_activation, seq_parallel }); |
|
|
|
const bytesPerValue = mixed ? 2 : 4; |
|
|
|
let oneLayerAttention; |
|
if (recomputation === "none" || recomputation === "full") { |
|
if (seq_parallel) { |
|
oneLayerAttention = s * b * h / tp * (bytesPerValue * 5 + 1) + ((2 * bytesPerValue + 1) * a * s * s * b); |
|
} else { |
|
oneLayerAttention = s * b * h * (bytesPerValue * 4 / tp + bytesPerValue + 1) + ((2 * bytesPerValue + 1) * a * s * s * b / tp); |
|
} |
|
} else if (recomputation === "selective") { |
|
if (seq_parallel) { |
|
oneLayerAttention = s * b * h / tp * (bytesPerValue * 5 + 1); |
|
} else { |
|
oneLayerAttention = s * b * h * (bytesPerValue * 4 / tp + bytesPerValue + 1); |
|
} |
|
} else { |
|
throw new Error("Invalid recomputation value"); |
|
} |
|
|
|
let oneLayerFeedforward; |
|
if (ff_activation === "relu") { |
|
if (seq_parallel) { |
|
oneLayerFeedforward = (s * b * h * bytesPerValue / tp + (s * b * h_ff * bytesPerValue / tp) |
|
+ s * b * h / tp); |
|
} else { |
|
oneLayerFeedforward = (s * b * h * bytesPerValue + (s * b * h_ff * bytesPerValue / tp) |
|
+ s * b * h); |
|
} |
|
} else if (ff_activation === "gelu") { |
|
if (seq_parallel) { |
|
oneLayerFeedforward = (s * b * h * bytesPerValue / tp + (s * b * h_ff * bytesPerValue / tp) |
|
+ s * b * h_ff * bytesPerValue / tp |
|
+ s * b * h / tp); |
|
} else { |
|
oneLayerFeedforward = (s * b * h * bytesPerValue + (s * b * h_ff * bytesPerValue / tp) |
|
+ s * b * h_ff * bytesPerValue / tp |
|
+ s * b * h); |
|
} |
|
} else if (ff_activation === "swiglu") { |
|
if (seq_parallel) { |
|
oneLayerFeedforward = (s * b * h * bytesPerValue / tp + (s * b * h_ff * bytesPerValue / tp) |
|
+ s * b * h_ff * bytesPerValue * 3 / tp |
|
+ s * b * h / tp); |
|
} else { |
|
oneLayerFeedforward = (s * b * h * bytesPerValue + (s * b * h_ff * bytesPerValue / tp) |
|
+ s * b * h_ff * bytesPerValue * 3 / tp |
|
+ s * b * h); |
|
} |
|
} |
|
|
|
let layerNorm; |
|
if (seq_parallel) { |
|
layerNorm = s * b * h * bytesPerValue / tp; |
|
} else { |
|
layerNorm = s * b * h * bytesPerValue; |
|
} |
|
|
|
const inputDropout = seq_parallel ? s * b * h / tp : s * b * h; |
|
const outputLayerNorm = seq_parallel ? s * b * h * bytesPerValue / tp : s * b * h * bytesPerValue; |
|
const outputLayerProjection = seq_parallel ? s * b * h * bytesPerValue / tp : s * b * h * bytesPerValue; |
|
const outputCrossEntropy = seq_parallel ? s * b * v * 4 / tp : s * b * v * 4; |
|
|
|
|
|
let data |
|
if (recomputation === "none" || recomputation === "selective") { |
|
|
|
data = { |
|
name: "Activation Memory", |
|
children: [ |
|
...Array.from({ length: L }, (_, index) => ({ |
|
name: `Layer ${index + 1}`, |
|
children: [ |
|
{ name: 'Attention', value: oneLayerAttention }, |
|
{ name: 'Feedforward', value: oneLayerFeedforward }, |
|
{ name: 'LayerNorm', value: 2 * layerNorm }, |
|
] |
|
})), |
|
{ name: 'Dropout', value: inputDropout }, |
|
{ name: 'LayerNorm', value: outputLayerNorm }, |
|
{ name: 'Projection', value: outputLayerProjection }, |
|
{ name: 'Cross Entropy', value: outputCrossEntropy } |
|
] |
|
}; |
|
} else if (recomputation === "full") { |
|
data = { |
|
name: "Activation Memory", |
|
children: [ |
|
{ name: 'LayerInput', value: s * b * h * bytesPerValue * L }, |
|
{ name: 'Dropout', value: inputDropout }, |
|
{ name: 'LayerNorm', value: outputLayerNorm }, |
|
{ name: 'Projection', value: outputLayerProjection }, |
|
{ name: 'Cross Entropy', value: outputCrossEntropy } |
|
] |
|
}; |
|
} else { |
|
throw new Error("Invalid recomputation value"); |
|
} |
|
|
|
return data; |
|
} |
|
|
|
export function paramGradsOpt(h, L, s, v, k = 8, dp = 1, zero = 0, mixed = true) { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
console.log('paramGradsOpt called with:', { h, L, s, v, k, dp, zero, mixed }); |
|
const emb = h * (v + s); |
|
const oneLayer = 12 * h ** 2 + 13 * h; |
|
const other = 2 * h; |
|
|
|
const n = emb + L * oneLayer + other; |
|
|
|
if (mixed) { |
|
k += 4; |
|
} |
|
const bytesPerParameter = mixed ? 2 : 4; |
|
|
|
const data = { |
|
name: "Parameters / Gradients / Optimizer States", |
|
children: [ |
|
{ name: 'Parameters', value: zero >= 3 ? bytesPerParameter * n / dp : bytesPerParameter * n }, |
|
{ name: 'Gradients', value: zero >= 2 ? bytesPerParameter * n / dp : bytesPerParameter * n }, |
|
{ name: 'OptimizerAverages', value: zero >= 1 ? k * n / dp : k * n } |
|
] |
|
}; |
|
console.log('paramGradsOpt result:', data); |
|
return data; |
|
} |
|
|
|
export function updateGraph() { |
|
console.log('updateGraph called'); |
|
const a = +document.getElementById('a').value; |
|
const b = +document.getElementById('b').value; |
|
const h = +document.getElementById('h').value; |
|
const h_ff = +document.getElementById('h_ff').value; |
|
const L = +document.getElementById('L').value; |
|
const s = +document.getElementById('s').value; |
|
const v = +document.getElementById('v').value; |
|
const k = +document.getElementById('k').value; |
|
const tp = +document.getElementById('tp').value; |
|
const zero = document.getElementById('zero').value; |
|
const dp = document.getElementById('dp').value; |
|
const mixed = document.getElementById('mixed').checked; |
|
const recomputation = document.getElementById('recomputation').value; |
|
const ff_activation = document.getElementById('ff_activation').value; |
|
const seq_parallel = document.getElementById('seq_parallel').checked; |
|
|
|
console.log('Slider values:', { a, b, h, h_ff, L, s, v, k, tp, zero, dp, mixed, recomputation, ff_activation, seq_parallel }); |
|
|
|
const activationMemoryData = activationMemory(a, b, h, h_ff, L, s, v, tp, mixed, recomputation, ff_activation, seq_parallel); |
|
const paramGradsOptValue = paramGradsOpt(h, L, s, v, k, dp, zero, mixed); |
|
|
|
const data = { |
|
name: "root", |
|
children: [ |
|
{ |
|
name: 'Total', |
|
value: 0, |
|
children: [ |
|
activationMemoryData, |
|
paramGradsOptValue |
|
] |
|
} |
|
] |
|
}; |
|
|
|
console.log('Data for treemap:', data); |
|
|
|
const width = 600; |
|
const height = 600; |
|
const legendHeight = 50; |
|
|
|
const svg = d3.select("#graph").select("svg"); |
|
svg.selectAll("*").remove(); |
|
svg.attr("viewBox", [0, 0, width, height + legendHeight]); |
|
|
|
const treemap = d3.treemap() |
|
.size([width, height]) |
|
.paddingOuter(3) |
|
.paddingTop(19) |
|
.paddingInner(3) |
|
.round(true); |
|
|
|
const root = d3.hierarchy(data) |
|
.sum(d => d.value); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
console.log('Treemap root:', root); |
|
|
|
treemap(root); |
|
|
|
const color = d => { |
|
switch (d.data.name) { |
|
|
|
case 'root': return 'rgb(225, 225, 225)'; |
|
case 'Total': return 'rgb(225, 225, 225)'; |
|
|
|
|
|
case 'Activation Memory': return 'rgb(61, 198, 159)'; |
|
case 'Parameters / Gradients / Optimizer States': return 'rgba(232, 137, 170, 0.85)'; |
|
|
|
|
|
case 'Parameters': return 'rgb(206, 192, 250)'; |
|
case 'Gradients': return 'rgb(227, 138, 66)'; |
|
case 'OptimizerAverages': return 'rgb(78, 165, 183)'; |
|
|
|
|
|
case 'Attention': return 'rgb(206, 192, 250)'; |
|
case 'Feedforward': return 'rgb(171, 232, 241)'; |
|
case 'LayerNorm': return 'rgb(232, 137, 171)'; |
|
|
|
|
|
case 'Dropout': return 'rgb(67, 145, 108)'; |
|
case 'Projection': return 'rgb(174, 214, 251)'; |
|
case 'Cross Entropy': return 'rgb(232, 137, 171)'; |
|
|
|
|
|
default: return 'rgb(227, 138, 66)'; |
|
}; |
|
}; |
|
|
|
if (d3.select('#tooltip').empty()) { |
|
d3.select('body') |
|
.append('div') |
|
.attr('id', 'tooltip') |
|
.style('opacity', 0) |
|
.style('position', 'absolute') |
|
.style('background-color', 'white') |
|
.style('padding', '4px') |
|
.style('font-size', '12px') |
|
.style('border-radius', '5px') |
|
.style('box-shadow', '0px 0px 5px 0px rgba(0,0,0,0.3)'); |
|
} |
|
|
|
const cell = svg.selectAll("g") |
|
.data(root.descendants().filter(d => d.depth !== 0)) |
|
.join("g") |
|
.attr("transform", d => `translate(${d.x0},${d.y0})`) |
|
.on('mouseover', (event, d) => { |
|
const name = d.data.name; |
|
const value = formatBytes(d.value); |
|
d3.select('#tooltip').transition().duration(200).text(`${name}: ${value}`) |
|
}) |
|
.on('mouseout', function() { |
|
d3.select('#tooltip').style('opacity', 0) |
|
}) |
|
.on('mousemove', function(event) { |
|
d3.select('#tooltip').style('left', (event.pageX + 10) + 'px').style('top', (event.pageY + 10) + 'px').style('opacity', 1) |
|
}); |
|
|
|
cell.append("rect") |
|
.attr("width", d => d.x1 - d.x0) |
|
.attr("height", d => d.y1 - d.y0) |
|
.attr("fill", d => color(d)) |
|
.attr("stroke", d => d.depth === 1 ? color(d) : "white") |
|
.attr("stroke-width", 1); |
|
|
|
const fontSize = 10; |
|
const padding = 2; |
|
|
|
cell.append("text") |
|
.attr("font-size", `${fontSize}px`) |
|
.attr("font-family", "sans-serif") |
|
.each(function (d) { |
|
|
|
const node = d3.select(this); |
|
|
|
const name = d.data.name; |
|
const value = formatBytes(d.value); |
|
|
|
if (d.depth === 1 || d.depth === 2) { |
|
node.attr("transform", `translate(${padding},${fontSize + padding})`) |
|
.attr("font-weight", "bold") |
|
.attr("font-size", 12) |
|
.text(`${name}: ${value}`); |
|
} else { |
|
|
|
node.attr("transform", `translate(${padding},${fontSize + padding})`) |
|
.text(name[0].toUpperCase()) |
|
.attr("font-weight", "bold") |
|
.append("title") |
|
.text(`${name}: ${value}`); |
|
} |
|
}); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
console.log('Treemap nodes created'); |
|
} |
|
|
|
function formatBytes(bytes) { |
|
const sizes = ['Bytes', 'KB', 'MB', 'GB', 'TB', 'PB']; |
|
if (bytes === 0) return '0 Bytes'; |
|
const i = parseInt(Math.floor(Math.log(bytes) / Math.log(1024)), 10); |
|
return `${(bytes / (1024 ** i)).toFixed(2)} ${sizes[i]}`; |
|
} |
|
|
|
const presets = { |
|
"Llama 3 Tiny": { a: 16, b: 3, h: 1024, h_ff: 4096, L: 1, s: 7, v: 30522, k: 8, tp: 1, zero: "1", dp: 1, mixed: true, recomputation: "none", ff_activation: "gelu", seq_parallel: false }, |
|
"Llama 3 8B": { a: 32, b: 32, h: 4096, h_ff: 16384, L: 32, s: 256, v: 30522, k: 8, tp: 1, zero: "1", dp: 1, mixed: true, recomputation: "none", ff_activation: "swiglu", seq_parallel: false }, |
|
"Llama 3 70B": { a: 64, b: 32, h: 8192, h_ff: 32768, L: 80, s: 256, v: 30522, k: 8, tp: 8, zero: "1", dp: 8, mixed: true, recomputation: "none", ff_activation: "swiglu", seq_parallel: false }, |
|
"Llama 3 405B": { a: 128, b: 32, h: 16384, h_ff: 65536, L: 126, s: 256, v: 30522, k: 8, tp: 8, zero: "1", dp: 8, mixed: true, recomputation: "none", ff_activation: "swiglu", seq_parallel: false } |
|
}; |
|
|
|
function setPresetValues(preset) { |
|
if (preset === "custom") return; |
|
|
|
const values = presets[preset]; |
|
Object.keys(values).forEach(key => { |
|
const element = document.getElementById(key); |
|
const inputElement = document.getElementById(`${key}_input`); |
|
if (element) { |
|
if (element.type === 'checkbox') { |
|
element.checked = values[key]; |
|
} else { |
|
element.value = values[key]; |
|
} |
|
} |
|
if (inputElement) { |
|
inputElement.value = values[key]; |
|
} |
|
}); |
|
|
|
updateGraph(); |
|
} |
|
|
|
function syncSliderAndInput(sliderId, inputId) { |
|
const slider = document.getElementById(sliderId); |
|
const input = document.getElementById(inputId); |
|
|
|
slider.addEventListener('input', () => { |
|
input.value = slider.value; |
|
updateGraph(); |
|
}); |
|
|
|
input.addEventListener('input', () => { |
|
let value = parseInt(input.value); |
|
if (isNaN(value)) { |
|
value = parseInt(slider.min); |
|
} |
|
value = Math.max(parseInt(slider.min), Math.min(parseInt(slider.max), value)); |
|
slider.value = value; |
|
input.value = value; |
|
updateGraph(); |
|
}); |
|
} |
|
|
|
export const init_memory_plot = function () { |
|
console.log('Initializing memory plot'); |
|
|
|
const sliderIds = ['a', 'b', 'h', 'h_ff', 'L', 's', 'v', 'k', 'tp', 'dp']; |
|
sliderIds.forEach(id => { |
|
const slider = document.getElementById(id); |
|
const input = document.getElementById(`${id}_input`); |
|
if (slider && input) { |
|
syncSliderAndInput(id, `${id}_input`); |
|
} else { |
|
console.warn(`Elements for ${id} not found`); |
|
} |
|
}); |
|
|
|
const recomputationSelect = document.getElementById('recomputation'); |
|
if (recomputationSelect) { |
|
recomputationSelect.addEventListener('change', updateGraph); |
|
} else { |
|
console.warn('Recomputation select not found'); |
|
} |
|
|
|
const ffActivationSelect = document.getElementById('ff_activation'); |
|
if (ffActivationSelect) { |
|
ffActivationSelect.addEventListener('change', updateGraph); |
|
} else { |
|
console.warn('FF Activation select not found'); |
|
} |
|
|
|
const zeroSelect = document.getElementById('zero'); |
|
if (zeroSelect) { |
|
zeroSelect.addEventListener('change', updateGraph); |
|
} else { |
|
console.warn('Zero select not found'); |
|
} |
|
|
|
const mixedCheckbox = document.getElementById('mixed'); |
|
if (mixedCheckbox) { |
|
mixedCheckbox.addEventListener('change', updateGraph); |
|
} else { |
|
console.warn('Mixed checkbox not found'); |
|
} |
|
|
|
const seqParallelCheckbox = document.getElementById('seq_parallel'); |
|
if (seqParallelCheckbox) { |
|
seqParallelCheckbox.addEventListener('change', updateGraph); |
|
} else { |
|
console.warn('Seq Parallel checkbox not found'); |
|
} |
|
|
|
const presetSelect = document.getElementById('presets'); |
|
if (presetSelect) { |
|
presetSelect.addEventListener('change', (event) => { |
|
setPresetValues(event.target.value); |
|
}); |
|
} else { |
|
console.warn('Preset select not found'); |
|
} |
|
|
|
|
|
sliderIds.forEach(id => { |
|
const slider = document.getElementById(id); |
|
if (slider) { |
|
switch (id) { |
|
case 'a': slider.max = '128'; break; |
|
case 'b': slider.max = '53248'; break; |
|
case 'h': slider.max = '16384'; break; |
|
case 'h_ff': slider.max = '65536'; break; |
|
case 'L': slider.max = '126'; break; |
|
case 's': slider.max = '128000'; break; |
|
case 'v': slider.max = '100000'; break; |
|
case 'k': slider.max = '16'; break; |
|
case 'tp': slider.max = '16'; break; |
|
case 'dp': slider.max = '256'; break; |
|
} |
|
} else { |
|
console.warn(`Slider ${id} not found`); |
|
} |
|
}); |
|
|
|
console.log('Adding svg'); |
|
const graphContainer = document.getElementById('graph'); |
|
if (graphContainer) { |
|
const svg = d3.select("#graph") |
|
.append("svg") |
|
} else { |
|
console.warn('Graph container not found'); |
|
} |
|
|
|
updateGraph(); |
|
}; |