{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 48,
   "metadata": {},
   "outputs": [],
   "source": [
    "def activation_memory(\n",
    "    a, # attention heads\n",
    "    b, # micro batch size\n",
    "    h, # hidden dimension size\n",
    "    h_ff, # feedforward dimension size (often h_ff = 4h)\n",
    "    L, # number of layers\n",
    "    s, # sequence length\n",
    "    mixed=True,\n",
    "    recomputation=\"none\",\n",
    "    ff_activation=\"relu\"\n",
    "    ):\n",
    "    \"\"\"Estimate activation memory (in bytes) for one transformer forward pass.\n",
    "\n",
    "    Follows Korthikanti et al., \"Reducing Activation Recomputation in Large\n",
    "    Transformer Models\": https://arxiv.org/pdf/2205.05198\n",
    "    \"\"\"\n",
    "    if mixed:\n",
    "        bytes_per_value = 2\n",
    "    else:\n",
    "        bytes_per_value = 4\n",
    "\n",
    "    # attention activations (QKV/projection inputs, softmax in/out) plus 1-byte dropout masks; eq (2)\n",
    "    one_layer_attention = s * b * h * (bytes_per_value * 5 + 1) + (2 * bytes_per_value + 1) * a * s * s * b\n",
    "\n",
    "    if ff_activation == \"relu\":\n",
    "        one_layer_feedforward = (s * b * h * bytes_per_value      # input of 1st linear layer\n",
    "                                 + s * b * h_ff * bytes_per_value # input of 2nd linear layer\n",
    "                                 + s * b * h)                     # dropout mask (1 byte per value)\n",
    "    elif ff_activation == \"gelu\":\n",
    "        one_layer_feedforward = (s * b * h * bytes_per_value      # input of 1st linear layer\n",
    "                                 + s * b * h_ff * bytes_per_value # input of 2nd linear layer\n",
    "                                 + s * b * h_ff * bytes_per_value # input of the activation function (not needed for ReLU)\n",
    "                                 + s * b * h)                     # dropout mask (1 byte per value)\n",
    "    elif ff_activation == \"swiglu\":\n",
    "        one_layer_feedforward = (s * b * h * bytes_per_value          # input of the input linear layer\n",
    "                                 + s * b * h_ff * bytes_per_value     # input of the output linear layer\n",
    "                                 + s * b * h_ff * bytes_per_value * 3 # inputs of the activation function\n",
    "                                 + s * b * h)                         # dropout mask (1 byte per value)\n",
    "    else:\n",
    "        raise ValueError(f\"unknown ff_activation: {ff_activation}\")\n",
    "\n",
    "    layer_norm = s * b * h * bytes_per_value\n",
    "\n",
    "    if recomputation == \"none\":\n",
    "        one_layer = one_layer_attention + one_layer_feedforward + 2 * layer_norm # eq (2)\n",
    "    elif recomputation == \"selective\":\n",
    "        one_layer = s * b * h * 34 # eq (6), derived in the paper for 2-byte values\n",
    "    elif recomputation == \"full\":\n",
    "        one_layer = s * b * h * 2 # only the 2-byte layer input is stored\n",
    "    else:\n",
    "        raise ValueError(f\"unknown recomputation strategy: {recomputation}\")\n",
    "\n",
    "    input_dropout = s * b * h # dropout mask after the embeddings, section 4.3\n",
    "\n",
    "    total = L * one_layer + input_dropout\n",
    "\n",
    "    return total\n"
   ]
  },
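  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick illustration (my own example, not from the paper): compare the three\n",
    "# recomputation strategies at a GPT-3-like scale. The configuration values\n",
    "# below are illustrative assumptions.\n",
    "for strategy in (\"none\", \"selective\", \"full\"):\n",
    "    mem = activation_memory(a=96, b=1, h=12288, h_ff=4 * 12288, L=96, s=2048,\n",
    "                            recomputation=strategy, ff_activation=\"gelu\")\n",
    "    print(f\"{strategy:>9}: {mem / 1024**3:.1f} GB\")\n"
   ]
  },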
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "a = 16\n",
    "b = 3\n",
    "h = 1024\n",
    "h_ff = 4 * h\n",
    "L = 1\n",
    "s = 7  # tiny test value; a long-context run would use e.g. s = 128000\n",
    "recomputation = \"none\"\n",
    "mixed = True\n",
    "ff_activation = \"swiglu\"\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1108464"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "activation_memory(a=a, b=b, h=h, h_ff=h_ff, L=L, s=s, recomputation=recomputation, mixed=mixed, ff_activation=ff_activation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "from math import log\n",
    "\n",
    "def format_bytes(num_bytes):\n",
    "    \"\"\"Render a byte count as a human-readable string.\"\"\"\n",
    "    sizes = ['Bytes', 'KB', 'MB', 'GB', 'TB']\n",
    "    if num_bytes == 0:\n",
    "        return '0 Bytes'\n",
    "    i = int(log(num_bytes, 1024))\n",
    "    p = 1024 ** i\n",
    "    s = round(num_bytes / p, 2)\n",
    "    return f\"{s} {sizes[i]}\"\n"
   ]
  },
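  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sanity check of the helper on hand-picked values (illustrative, chosen here):\n",
    "# each input should land in the expected unit.\n",
    "[format_bytes(n) for n in (0, 512, 2048, 3 * 1024**2, 22 * 1024**4)]\n"
   ]
  },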
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1.06 MB'"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "format_bytes(activation_memory(a=a, b=b, h=h, h_ff=h_ff, L=L, s=s, recomputation=recomputation, mixed=mixed, ff_activation=ff_activation))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "jupyter",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}