import * as d3 from 'd3';

/**
 * Estimate per-GPU activation memory (in bytes) of a GPT-style transformer,
 * following Korthikanti et al., "Reducing Activation Recomputation in Large
 * Transformer Models" (https://arxiv.org/pdf/2205.05198).
 *
 * @param {number} a - attention heads
 * @param {number} b - micro batch size
 * @param {number} h - hidden dimension size
 * @param {number} h_ff - feedforward dimension size (often h_ff = 4h)
 * @param {number} L - number of layers
 * @param {number} s - sequence length
 * @param {number} v - vocab size
 * @param {number} [tp=1] - tensor model parallel degree
 * @param {boolean} [mixed=true] - 2 bytes per activation value when true, 4 otherwise
 * @param {"none"|"selective"|"full"} [recomputation="none"] - activation recomputation strategy
 * @param {"relu"|"gelu"|"swiglu"} [ff_activation="relu"] - feedforward activation function
 * @param {boolean} [seq_parallel=false] - additionally split the non-tensor-parallel regions across `tp` ranks
 * @returns {{name: string, children: object[]}} treemap-ready hierarchy whose leaves carry `value` in bytes
 * @throws {Error} if `recomputation` or `ff_activation` is not a supported value
 */
export function activationMemory(
    a, // attention heads
    b, // micro batch size
    h, // hidden dimension size
    h_ff, // feedforward dimension size (often h_ff = 4h)
    L, // number of layers
    s, // sequence length
    v, // vocab size
    tp = 1, // tensor model parallelism
    mixed = true,
    recomputation = "none",
    ff_activation = "relu",
    seq_parallel = false
) {
    console.log('activationMemory called with:', { a, b, h, h_ff, L, s, v, tp, mixed, recomputation, ff_activation, seq_parallel });
    // https://arxiv.org/pdf/2205.05198
    const bytesPerValue = mixed ? 2 : 4;

    if (recomputation !== "none" && recomputation !== "selective" && recomputation !== "full") {
        throw new Error("Invalid recomputation value");
    }

    // Attention-score tensors (softmax output, dropout mask, dropout output):
    // 5as²b bytes with 2-byte values, eq (2). They sit inside the
    // tensor-parallel region, so they are split across `tp` ranks both with
    // and without sequence parallelism (table 2: sbh·5as/(ht) in both rows).
    const attentionScores = (2 * bytesPerValue + 1) * a * s * s * b / tp;

    // The remaining attention activations (QKV input, Q/K/V, output-projection
    // input, dropout mask). Without sequence parallelism, the LayerNorm-side
    // input and the dropout mask are replicated on every tensor-parallel rank.
    const attentionLinear = seq_parallel
        ? s * b * h / tp * (bytesPerValue * 5 + 1)
        : s * b * h * (bytesPerValue * 4 / tp + bytesPerValue + 1);

    // Selective recomputation does not store the attention-score tensors (table 2).
    const oneLayerAttention = recomputation === "selective"
        ? attentionLinear
        : attentionLinear + attentionScores; // eq (2)

    let oneLayerFeedforward;
    if (ff_activation === "relu") {
        if (seq_parallel) {
            oneLayerFeedforward = (s * b * h * bytesPerValue / tp + (s * b * h_ff * bytesPerValue / tp) // inputs of 1st/2nd linear layers
                + s * b * h / tp);  // dropout
        } else {
            oneLayerFeedforward = (s * b * h * bytesPerValue + (s * b * h_ff * bytesPerValue / tp) // inputs of 1st/2nd linear layers
                + s * b * h);  // dropout
        }
    } else if (ff_activation === "gelu") {
        if (seq_parallel) {
            oneLayerFeedforward = (s * b * h * bytesPerValue / tp + (s * b * h_ff * bytesPerValue / tp) // inputs of 1st/2nd linear layers
                + s * b * h_ff * bytesPerValue / tp // inputs of activation function (not really necessary for Relu)
                + s * b * h / tp);  // dropout
        } else {
            oneLayerFeedforward = (s * b * h * bytesPerValue + (s * b * h_ff * bytesPerValue / tp) // inputs of 1st/2nd linear layers
                + s * b * h_ff * bytesPerValue / tp // inputs of activation function (not really necessary for Relu)
                + s * b * h);  // dropout
        }
    } else if (ff_activation === "swiglu") {
        if (seq_parallel) {
            oneLayerFeedforward = (s * b * h * bytesPerValue / tp + (s * b * h_ff * bytesPerValue / tp) // inputs of input/output linear layers
                + s * b * h_ff * bytesPerValue * 3 / tp // inputs of activation function
                + s * b * h / tp);  // dropout (note that dropout is lower-precision - boolean)
        } else {
            oneLayerFeedforward = (s * b * h * bytesPerValue + (s * b * h_ff * bytesPerValue / tp) // inputs of input/output linear layers
                + s * b * h_ff * bytesPerValue * 3 / tp // inputs of activation function
                + s * b * h);  // dropout (note that dropout is lower-precision - boolean)
        }
    } else {
        // Fail loudly, as for `recomputation`, instead of emitting an undefined value.
        throw new Error(`Invalid ff_activation value: ${ff_activation}`);
    }

    // Each layer has two LayerNorms, whose inputs are only split across ranks with sequence parallelism.
    const layerNorm = seq_parallel ? s * b * h * bytesPerValue / tp : s * b * h * bytesPerValue;

    const inputDropout = seq_parallel ? s * b * h / tp : s * b * h; // section 4.3
    const outputLayerNorm = seq_parallel ? s * b * h * bytesPerValue / tp : s * b * h * bytesPerValue;
    const outputLayerProjection = seq_parallel ? s * b * h * bytesPerValue / tp : s * b * h * bytesPerValue;
    const outputCrossEntropy = seq_parallel ? s * b * v * 4 / tp : s * b * v * 4;  // In FP32

    // Activations outside the transformer layers; stored for every recomputation mode.
    const nonLayerChildren = [
        { name: 'Dropout', value: inputDropout },
        { name: 'LayerNorm', value: outputLayerNorm },
        { name: 'Projection', value: outputLayerProjection },
        { name: 'Cross Entropy', value: outputCrossEntropy }
    ];

    if (recomputation === "full") {
        // Full recomputation keeps only each layer's input.
        // NOTE(review): this term is not divided by `tp` under seq_parallel — confirm whether that is intended.
        return {
            name: "Activation Memory",
            children: [
                { name: 'LayerInput', value: s * b * h * bytesPerValue * L },
                ...nonLayerChildren
            ]
        };
    }

    return {
        name: "Activation Memory",
        children: [
            ...Array.from({ length: L }, (_, index) => ({
                name: `Layer ${index + 1}`,
                children: [
                    { name: 'Attention', value: oneLayerAttention },
                    { name: 'Feedforward', value: oneLayerFeedforward },
                    { name: 'LayerNorm', value: 2 * layerNorm },
                ]
            })),
            ...nonLayerChildren
        ]
    };
}

/**
 * Estimate memory (bytes) for model parameters, gradients and optimizer
 * states, optionally sharded with ZeRO across `dp` data-parallel ranks.
 *
 * Parameter count: n = h·(v + s) + L·(12h² + 13h) + 2h. The 12h² per-layer
 * term corresponds to a 4h feedforward; h_ff is not taken into account here.
 *
 * @param {number} h - hidden dimension size
 * @param {number} L - number of layers
 * @param {number} s - sequence length
 * @param {number} v - vocab size
 * @param {number|string} [k=8] - optimizer bytes per parameter (Adam: 8 = 4 bytes moments + 4 bytes variance)
 * @param {number|string} [dp=1] - data-parallel degree
 * @param {number|string} [zero=0] - ZeRO stage (0-3); stage >= 1 shards optimizer states,
 *   >= 2 also gradients, >= 3 also parameters
 * @param {boolean} [mixed=true] - mixed precision: 2-byte params/grads plus 4 extra optimizer bytes per parameter
 * @returns {{name: string, children: {name: string, value: number}[]}} treemap-ready node
 */
export function paramGradsOpt(h, L, s, v, k = 8, dp = 1, zero = 0, mixed = true) {
    console.log('paramGradsOpt called with:', { h, L, s, v, k, dp, zero, mixed });
    const emb = h * (v + s);
    const oneLayer = 12 * h ** 2 + 13 * h;
    const other = 2 * h;

    const n = emb + L * oneLayer + other;

    // Form controls may deliver these as strings; coerce explicitly so that
    // `k + 4` can never become string concatenation ("8" + 4 === "84").
    const zeroStage = Number(zero);
    const dpSize = Number(dp);
    // With mixed precision, 4 more bytes per parameter are kept with the
    // optimizer states (presumably the fp32 master weights).
    const optimizerBytesPerParameter = mixed ? Number(k) + 4 : Number(k);
    const bytesPerParameter = mixed ? 2 : 4;

    // Divide `bytes` across data-parallel ranks once ZeRO reaches `stage`.
    const shardFromStage = (stage, bytes) => (zeroStage >= stage ? bytes / dpSize : bytes);

    const data = {
        name: "Parameters / Gradients / Optimizer States",
        children: [
            { name: 'Parameters', value: shardFromStage(3, bytesPerParameter * n) },
            { name: 'Gradients', value: shardFromStage(2, bytesPerParameter * n) },
            { name: 'OptimizerAverages', value: shardFromStage(1, optimizerBytesPerParameter * n) }
        ]
    };
    console.log('paramGradsOpt result:', data);
    return data;
}

/**
 * Re-read every control on the page, recompute the memory breakdown and
 * redraw the treemap inside the `#graph` element's `<svg>`.
 *
 * Reads the controls with ids a, b, h, h_ff, L, s, v, k, tp, zero, dp (values),
 * mixed and seq_parallel (checkboxes), recomputation and ff_activation (values).
 * Removes the SVG's previous contents before drawing, and creates a shared
 * `#tooltip` div on the first call.
 */
export function updateGraph() {
    console.log('updateGraph called');
    const a = +document.getElementById('a').value;
    const b = +document.getElementById('b').value;
    const h = +document.getElementById('h').value;
    const h_ff = +document.getElementById('h_ff').value;
    const L = +document.getElementById('L').value;
    const s = +document.getElementById('s').value;
    const v = +document.getElementById('v').value;
    const k = +document.getElementById('k').value;
    const tp = +document.getElementById('tp').value;  // tensor-parallel degree
    // Coerced like every other numeric control so paramGradsOpt receives numbers, not strings.
    const zero = +document.getElementById('zero').value;
    const dp = +document.getElementById('dp').value;
    const mixed = document.getElementById('mixed').checked;
    const recomputation = document.getElementById('recomputation').value;
    const ff_activation = document.getElementById('ff_activation').value;
    const seq_parallel = document.getElementById('seq_parallel').checked;

    console.log('Slider values:', { a, b, h, h_ff, L, s, v, k, tp, zero, dp, mixed, recomputation, ff_activation, seq_parallel });

    const activationMemoryData = activationMemory(a, b, h, h_ff, L, s, v, tp, mixed, recomputation, ff_activation, seq_parallel);
    const paramGradsOptValue = paramGradsOpt(h, L, s, v, k, dp, zero, mixed);

    // 'Total' is only a container; hierarchy.sum() below accumulates its children's values into it.
    const data = {
        name: "root",
        children: [
            {
                name: 'Total',
                value: 0,
                children: [
                    activationMemoryData,
                    paramGradsOptValue
                ]
            }
        ]
    };

    console.log('Data for treemap:', data);

    const width = 700;
    const height = 450;
    const legendHeight = 50;

    const svg = d3.select("#graph").select("svg");
    svg.selectAll("*").remove();
    svg.attr("viewBox", [0, 0, width, height + legendHeight]);

    const treemap = d3.treemap()
        .size([width, height])
        .paddingOuter(3)
        .paddingTop(19)
        .paddingInner(3)
        .round(true);

    const root = d3.hierarchy(data)
        .sum(d => d.value);
    // .sort((a, b) => b.value - a.value);

    // const fixedSize100GB = 100 * 1024 * 1024 * 1024; // 100GB in bytes
    // if (root.children[0].value < fixedSize100GB) {
    //     root.value = fixedSize100GB;
    //     root.children[0].value = fixedSize100GB;
    // }

    console.log('Treemap root:', root);

    treemap(root);

    // Fill color per node name. 'Layer N' nodes and anything unlisted fall back to orange.
    const defaultColor = 'rgb(227, 138, 66)';  // orange
    const palette = new Map([
        // Root and Total (container levels)
        ['root', 'rgb(225, 225, 225)'],  // light grey
        ['Total', 'rgb(225, 225, 225)'],  // light grey

        // Main section containers
        ['Activation Memory', 'rgb(78, 165, 183)'],  // teal
        ['Parameters / Gradients / Optimizer States', 'rgb(232, 137, 171)'],  // pink

        // Parameters / Gradients / Optimizer States branch
        ['Parameters', 'rgb(206, 192, 250)'],  // lavender
        ['Gradients', 'rgb(227, 138, 66)'],  // orange
        ['OptimizerAverages', 'rgb(78, 165, 183)'],  // teal

        // activationMemory branch - layer components
        ['Attention', 'rgb(206, 192, 250)'],  // lavender
        ['Feedforward', 'rgb(171, 232, 241)'],  // light blue
        ['LayerNorm', 'rgb(232, 137, 171)'],  // pink

        // activationMemory branch - other components
        ['Dropout', 'rgb(67, 145, 108)'],  // dark green
        ['Projection', 'rgb(174, 214, 251)'],  // sky blue
        ['Cross Entropy', 'rgb(232, 137, 171)'],  // pink
    ]);
    const color = d => palette.get(d.data.name) || defaultColor;

    if (d3.select('#tooltip').empty()) {
        d3.select('body')
            .append('div')
            .attr('id', 'tooltip')
            .style('opacity', 0)
            .style('position', 'absolute')
            .style('background-color', 'white')
            .style('padding', '4px')
            .style('font-size', '12px')
            .style('border-radius', '5px')
            .style('box-shadow', '0px 0px 5px 0px rgba(0,0,0,0.3)');
    }

    const cell = svg.selectAll("g")
        .data(root.descendants().filter(d => d.depth !== 0)) // Skip root node
        .join("g")
        .attr("transform", d => `translate(${d.x0},${d.y0})`)
        .on('mouseover', (event, d) => {
            const name = d.data.name;
            const value = formatBytes(d.value);
            d3.select('#tooltip').transition().duration(200).text(`${name}: ${value}`);
        })
        .on('mouseout', function () {
            d3.select('#tooltip').style('opacity', 0);
        })
        .on('mousemove', function (event) {
            d3.select('#tooltip').style('left', (event.pageX + 10) + 'px').style('top', (event.pageY + 10) + 'px').style('opacity', 1);
        });

    cell.append("rect")
        .attr("width", d => d.x1 - d.x0)
        .attr("height", d => d.y1 - d.y0)
        .attr("fill", d => color(d))
        .attr("stroke", d => d.depth === 1 ? color(d) : "white")
        .attr("stroke-width", 1);

    const fontSize = 10;
    const padding = 2;

    cell.append("text")
        .attr("font-size", `${fontSize}px`)
        .attr("font-family", "sans-serif")
        .each(function (d) {
            const node = d3.select(this);

            const name = d.data.name;
            const value = formatBytes(d.value);

            if (d.depth === 1 || d.depth === 2) {
                // Section headers: full "name: size" label in bold.
                node.attr("transform", `translate(${padding},${fontSize + padding})`)
                    .attr("font-weight", "bold")
                    .attr("font-size", 12)
                    .text(`${name}: ${value}`);
            } else {
                // Deeper nodes are often tiny: show only the first letter, full label via <title>.
                node.attr("transform", `translate(${padding},${fontSize + padding})`)
                    .text(name[0].toUpperCase())
                    .attr("font-weight", "bold")
                    .append("title")
                    .text(`${name}: ${value}`);
            }
        });

    /* 
    // Adjust legend positioning
    const legendData = root.children[0].children.concat(root.children[0]);
    const legend = svg.append("g")
        .attr("font-family", "sans-serif")
        .attr("font-size", 10)
        .attr("text-anchor", "start")
        .attr("transform", `translate(0, ${height})`)
        .selectAll("g")
        .data(legendData)
        .join("g")
        .attr("transform", (d, i) => `translate(${i * 240}, 0)`);

    legend.append("rect")
        .attr("x", 0)
        .attr("width", 19)
        .attr("height", 19)
        .attr("fill", d => color(d))
        .attr("stroke", '#f3f3f3')
        .attr("stroke-width", 0);

    legend.append("text")
        .attr("x", 24)
        .attr("y", 9.5)
        .attr("dy", "0.32em")
        .text(d => `${d.data.name}: ${formatBytes(d.value)}`);
    */
    console.log('Treemap nodes created');
}
    
/**
 * Format a byte count using binary (1024-based) units with two decimals,
 * e.g. 1536 -> "1.50 KB".
 *
 * Values below 1 stay in "Bytes". Values beyond the largest unit stay in that
 * unit rather than indexing past the unit table. Negative values are scaled by
 * their magnitude. Non-finite input is echoed as "<value> Bytes".
 *
 * @param {number} bytes
 * @returns {string}
 */
function formatBytes(bytes) {
    const sizes = ['Bytes', 'KB', 'MB', 'GB', 'TB', 'PB', 'EB', 'ZB', 'YB'];
    if (bytes === 0) return '0 Bytes';
    if (!Number.isFinite(bytes)) return `${bytes} Bytes`;
    // Math.floor already yields an integer; clamp so tiny or huge values still map to a real unit.
    const exponent = Math.floor(Math.log(Math.abs(bytes)) / Math.log(1024));
    const i = Math.min(Math.max(exponent, 0), sizes.length - 1);
    return `${(bytes / (1024 ** i)).toFixed(2)} ${sizes[i]}`;
}

// Model presets applied by setPresetValues; each key is the id of a control on the page.
// h_ff values are the published Llama 3 FFN (intermediate) sizes, not 4h.
// NOTE(review): v = 30522 is BERT's vocabulary size, not Llama 3's (128256); the latter
// exceeds the `v` slider max set in init_memory_plot (100000) — confirm before changing.
const presets = {
    "Llama 3 Tiny": { a: 16, b: 3, h: 1024, h_ff: 4096, L: 1, s: 7, v: 30522, k: 8, tp: 1, zero: "1", dp: 1, mixed: true, recomputation: "none", ff_activation: "gelu", seq_parallel: false },
    "Llama 3 8B": { a: 32, b: 32, h: 4096, h_ff: 14336, L: 32, s: 256, v: 30522, k: 8, tp: 1, zero: "1", dp: 1, mixed: true, recomputation: "none", ff_activation: "swiglu", seq_parallel: false },
    "Llama 3 70B": { a: 64, b: 32, h: 8192, h_ff: 28672, L: 80, s: 256, v: 30522, k: 8, tp: 8, zero: "1", dp: 8, mixed: true, recomputation: "none", ff_activation: "swiglu", seq_parallel: false },
    "Llama 3 405B": { a: 128, b: 32, h: 16384, h_ff: 53248, L: 126, s: 256, v: 30522, k: 8, tp: 8, zero: "1", dp: 8, mixed: true, recomputation: "none", ff_activation: "swiglu", seq_parallel: false }
};

/**
 * Copy a named preset into the page controls and redraw the graph.
 *
 * For every key of the preset, the element with that id gets the value
 * (`checked` for checkboxes, `value` otherwise), and so does a companion
 * `${key}_input` element when one exists. "custom" leaves the controls as they are.
 *
 * @param {string} preset - a key of `presets`, or "custom"
 */
function setPresetValues(preset) {
    if (preset === "custom") return;

    // Ignore names with no preset entry instead of throwing on `Object.keys(undefined)`.
    if (!Object.prototype.hasOwnProperty.call(presets, preset)) {
        console.warn(`Unknown preset: ${preset}`);
        return;
    }

    const values = presets[preset];
    Object.entries(values).forEach(([key, value]) => {
        const element = document.getElementById(key);
        if (element) {
            if (element.type === 'checkbox') {
                element.checked = value;
            } else {
                element.value = value;
            }
        }

        const companionInput = document.getElementById(`${key}_input`);
        if (companionInput) {
            companionInput.value = value;
        }
    });

    // Redraw so the graph reflects the newly applied preset.
    updateGraph();
}

/**
 * Keep a slider and its companion text/number input in sync. Either control
 * changing redraws the graph.
 *
 * Typed values are parsed as base-10 integers and clamped to the slider's
 * [min, max]. Non-numeric input falls back to the slider's min.
 *
 * @param {string} sliderId - id of the slider element
 * @param {string} inputId - id of the companion input element
 */
function syncSliderAndInput(sliderId, inputId) {
    const slider = document.getElementById(sliderId);
    const input = document.getElementById(inputId);

    // Parse a min/max attribute. When the attribute is unset (""), use the HTML
    // range-input default rather than letting NaN poison the clamp.
    const readBound = (attribute, fallback) => {
        const parsed = Number.parseInt(attribute, 10);
        return Number.isNaN(parsed) ? fallback : parsed;
    };

    slider.addEventListener('input', () => {
        input.value = slider.value;
        updateGraph();
    });

    input.addEventListener('input', () => {
        // Bounds are read at event time because slider.max may be changed after wiring.
        const min = readBound(slider.min, 0);
        const max = readBound(slider.max, 100);

        let value = Number.parseInt(input.value, 10);
        if (Number.isNaN(value)) {
            value = min;
        }
        value = Math.max(min, Math.min(max, value));
        slider.value = value;
        input.value = value;
        updateGraph();
    });
}

/**
 * Wire up all memory-plot controls, set slider maxima, create the `<svg>`
 * inside `#graph` (once), and draw the initial graph.
 *
 * Missing elements are reported with console.warn and otherwise skipped.
 */
export const init_memory_plot = function () {
    console.log('Initializing memory plot');

    const sliderIds = ['a', 'b', 'h', 'h_ff', 'L', 's', 'v', 'k', 'tp', 'dp'];
    sliderIds.forEach(id => {
        const slider = document.getElementById(id);
        const input = document.getElementById(`${id}_input`);
        if (slider && input) {
            syncSliderAndInput(id, `${id}_input`);
        } else {
            console.warn(`Elements for ${id} not found`);
        }
    });

    // Non-slider controls that only need to trigger a redraw, with the warning logged if absent.
    const changeControls = [
        ['recomputation', 'Recomputation select not found'],
        ['ff_activation', 'FF Activation select not found'],
        ['zero', 'Zero select not found'],
        ['mixed', 'Mixed checkbox not found'],
        ['seq_parallel', 'Seq Parallel checkbox not found'],
    ];
    changeControls.forEach(([id, missingMessage]) => {
        const control = document.getElementById(id);
        if (control) {
            control.addEventListener('change', updateGraph);
        } else {
            console.warn(missingMessage);
        }
    });

    const presetSelect = document.getElementById('presets');
    if (presetSelect) {
        presetSelect.addEventListener('change', (event) => {
            setPresetValues(event.target.value);
        });
    } else {
        console.warn('Preset select not found');
    }

    // Upper bounds for each slider (attribute values are strings).
    const sliderMax = {
        a: '128',
        b: '53248',
        h: '16384',
        h_ff: '65536',
        L: '126',
        s: '128000',
        v: '100000',
        k: '16',
        tp: '16',
        dp: '256',
    };
    sliderIds.forEach(id => {
        const slider = document.getElementById(id);
        if (slider) {
            slider.max = sliderMax[id];
        } else {
            console.warn(`Slider ${id} not found`);
        }
    });

    console.log('Adding svg');
    const graphContainer = document.getElementById('graph');
    if (graphContainer) {
        // updateGraph() draws into the first <svg> under #graph, so never append a second one.
        const container = d3.select(graphContainer);
        if (container.select('svg').empty()) {
            container.append('svg');
        }
    } else {
        console.warn('Graph container not found');
    }

    updateGraph();
};