{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "Ab4iZfp4eXPk" }, "source": [ "# Bike Rides and the Poisson Model\n", "\n", "To help the urban planners, you are asked to model the daily bike rides in NYC using [this dataset](https://gist.github.com/sachinsdate/c17931a3f000492c1c42cf78bf4ce9fe/archive/7a5131d3f02575668b3c7e8c146b6a285acd2cd7.zip). The dataset contains the date, high and low temperature, precipitation, and bike ride counts as columns. \n", "\n" ] }, { "cell_type": "code", "source": [ "!wget https://gist.github.com/sachinsdate/c17931a3f000492c1c42cf78bf4ce9fe/archive/7a5131d3f02575668b3c7e8c146b6a285acd2cd7.zip\n", "!unzip 7a5131d3f02575668b3c7e8c146b6a285acd2cd7.zip" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "gM-y0BWye7WN", "outputId": "ba5d103a-ec03-42ec-8bf2-1094d2ed99ee" }, "execution_count": 1, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "--2023-02-26 21:10:52-- https://gist.github.com/sachinsdate/c17931a3f000492c1c42cf78bf4ce9fe/archive/7a5131d3f02575668b3c7e8c146b6a285acd2cd7.zip\n", "Resolving gist.github.com (gist.github.com)... 192.30.255.112\n", "Connecting to gist.github.com (gist.github.com)|192.30.255.112|:443... connected.\n", "HTTP request sent, awaiting response... 302 Found\n", "Location: https://codeload.github.com/gist/c17931a3f000492c1c42cf78bf4ce9fe/zip/7a5131d3f02575668b3c7e8c146b6a285acd2cd7 [following]\n", "--2023-02-26 21:10:53-- https://codeload.github.com/gist/c17931a3f000492c1c42cf78bf4ce9fe/zip/7a5131d3f02575668b3c7e8c146b6a285acd2cd7\n", "Resolving codeload.github.com (codeload.github.com)... 192.30.255.120\n", "Connecting to codeload.github.com (codeload.github.com)|192.30.255.120|:443... connected.\n", "HTTP request sent, awaiting response... 
200 OK\n", "Length: unspecified [application/zip]\n", "Saving to: ‘7a5131d3f02575668b3c7e8c146b6a285acd2cd7.zip’\n", "\n", "7a5131d3f02575668b3 [ <=> ] 2.56K --.-KB/s in 0s \n", "\n", "2023-02-26 21:10:53 (27.7 MB/s) - ‘7a5131d3f02575668b3c7e8c146b6a285acd2cd7.zip’ saved [2623]\n", "\n", "Archive: 7a5131d3f02575668b3c7e8c146b6a285acd2cd7.zip\n", "7a5131d3f02575668b3c7e8c146b6a285acd2cd7\n", " creating: c17931a3f000492c1c42cf78bf4ce9fe-7a5131d3f02575668b3c7e8c146b6a285acd2cd7/\n", " inflating: c17931a3f000492c1c42cf78bf4ce9fe-7a5131d3f02575668b3c7e8c146b6a285acd2cd7/nyc_bb_bicyclist_counts.csv \n" ] } ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "id": "RxdI4hgDeXPr" }, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "\n", "np.random.seed(0)\n", "sns.set_theme(style='whitegrid', palette='pastel')\n", "\n", "import warnings\n", "warnings.filterwarnings('ignore')" ] }, { "cell_type": "code", "source": [ "filename = \"c17931a3f000492c1c42cf78bf4ce9fe-7a5131d3f02575668b3c7e8c146b6a285acd2cd7/nyc_bb_bicyclist_counts.csv\"\n", "\n", "df = pd.read_csv(filename)\n", "df.head()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 206 }, "id": "ImGoQW97gLQN", "outputId": "9e481eb5-d4f1-4d14-c8b2-d128abdab273" }, "execution_count": 3, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " Date HIGH_T LOW_T PRECIP BB_COUNT\n", "0 1-Apr-17 46.0 37.0 0.00 606\n", "1 2-Apr-17 62.1 41.0 0.00 2021\n", "2 3-Apr-17 63.0 50.0 0.03 2470\n", "3 4-Apr-17 51.1 46.0 1.18 723\n", "4 5-Apr-17 63.0 46.0 0.00 2807" ], "text/html": [ "\n", "
\n", "
\n", "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
   Date      HIGH_T  LOW_T  PRECIP  BB_COUNT
0  1-Apr-17  46.0    37.0   0.00    606
1  2-Apr-17  62.1    41.0   0.00    2021
2  3-Apr-17  63.0    50.0   0.03    2470
3  4-Apr-17  51.1    46.0   1.18    723
4  5-Apr-17  63.0    46.0   0.00    2807
\n", "
\n", " \n", " \n", " \n", "\n", " \n", "
\n", "
\n", " " ] }, "metadata": {}, "execution_count": 3 } ] }, { "cell_type": "markdown", "metadata": { "id": "rdF7NfHKeXPo" }, "source": [ "## Maximum Likelihood I \n", " \n", "The natural choice of distribution is the [Poisson distribution](https://en.wikipedia.org/wiki/Poisson_distribution), which depends on a single parameter, λ, the average number of occurrences per interval. We want to estimate this parameter using Maximum Likelihood Estimation.\n", "\n", "Implement a Gradient Descent algorithm from scratch that estimates the Poisson parameter according to the Maximum Likelihood criterion. Plot the estimated mean vs iterations to showcase convergence towards the true mean. \n", "\n", "References: \n", "\n", "1. [This blog post](https://towardsdatascience.com/the-poisson-process-everything-you-need-to-know-322aa0ab9e9a). \n", "\n", "2. [This blog post](https://towardsdatascience.com/understanding-maximum-likelihood-estimation-fa495a03017a) and note the negative log likelihood function. \n" ] }, { "cell_type": "markdown", "source": [ "Negative log likelihood function for the Poisson distribution:\n", "\n", "$n \\lambda + \\sum_{i=1}^n \\ln(x_i!) - \\ln(\\lambda) \\sum_{i=1}^n x_i$\n", "\n", "Derivative of the negative log likelihood with respect to $\\lambda$:\n", "\n", "$n - {1 \\over \\lambda} \\sum_{i=1}^n x_i$\n", "\n", "Setting this derivative to zero gives $\\hat\\lambda = {1 \\over n} \\sum_{i=1}^n x_i$, the sample mean, so gradient descent should converge to the mean of the counts." ], "metadata": { "id": "_TdBnxVrlRYH" } }, { "cell_type": "code", "source": [ "def gradient(l, x):\n", "    # Derivative of the negative log likelihood: n - (1 / lambda) * sum(x)\n", "    n = len(x)\n", "    lambda_gradient = n - (1 / l) * np.sum(x)\n", "    return lambda_gradient\n", "\n", "def gradient_descent(x, l, learning_rate, max_iter):\n", "    lam = l\n", "    lams = []  # lambda before each update, kept for the convergence plot\n", "    for _ in range(max_iter):\n", "        g = gradient(lam, x)\n", "\n", "        lams.append(lam)\n", "        lam = lam - learning_rate * g\n", "    return lam, lams" ], "metadata": { "id": "noR1cKD4hS2c" }, "execution_count": 18, "outputs": [] }, { "cell_type": "code", "source": [ "cyclists = np.array(df['BB_COUNT'])\n", "iterations = 100000\n", "alpha = 0.001\n", "lam = 1\n", "\n", "estimated, lams = gradient_descent(cyclists, lam, alpha, iterations)\n", "\n", "# Closed-form maximum likelihood estimate: the sample mean\n", "actual = (1 / len(cyclists)) * np.sum(cyclists)\n", "\n", "print(f\"Estimated mean={estimated}\")\n", "print(f\"Actual mean={actual}\")" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "WsUk4PlAmY7L", "outputId": "c4eea6f9-db9a-47cc-e856-128348dcc496" }, "execution_count": 39, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "Estimated mean=2679.715270113102\n", "Actual mean=2680.042056074766\n" ] } ] }, { "cell_type": "code", "source": [ "sns.lineplot(x=range(len(lams)), y=lams)\n", "plt.xlabel('Iterations')\n", "plt.ylabel('Estimated mean')\n", "plt.show()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 285 }, "id": "Dz-7upxEoKno", "outputId": "41370016-06be-4c4c-fa64-7a5a294fe3f6" }, "execution_count": 38, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "iVBORw0KGgoAAAANSUhEUgAAAZkAAAEMCAYAAAAWDss+AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMywgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/NK7nSAAAACXBIWXMAAAsTAAALEwEAmpwYAAAqRklEQVR4nO3de3wU9b0//tfMJru5uyQhsAkRNFy+0RwFEgS10RpEQLlpRSgP8FaQgw+prQWPLZjQBKsBjtpWPKBH6dFy5FF+cMIBLFELVGuPGNKApIb71SQk5EqSTfYy8/n9sWQlQMJsYPb6ej4emN357GTfn5jMa+fzmYskhBAgIiLSgezrAoiIKHgxZIiISDcMGSIi0g1DhoiIdMOQISIi3YT5ugBfUFUVbW1tCA8PhyRJvi6HiCggCCHgcDgQHR0NWda2jxKSIdPW1obDhw/7ugwiooA0dOhQxMbGanptSIZMeHg4ANcPymg0erx+eXk5MjIyrndZfo19Dg2h1udQ6y9wbX222+04fPiwexuqRUiGTOcQmdFohMlk6tX36O16gYx9Dg2h1udQ6y9w7X32ZJqBE/9ERKQbhgwREemGIUNERLphyBARkW4YMkREpJuQPLqMiCgYablzi7dv7sKQIQpAQggIAKoKqAJQBKAKAVV8v0wVAqra2XZ5uxACKlwbHddzoF7tg39WO77//heWd3l8oc29Hr5ffvFz4Pt1cfEy92PX98Elr+n+9bjo9drXxcWPL+FQBuNYmbXrSrjy67vbOF+2uIf367Jc6+uu8KJryQkTBiHrGtb3FEOG6DpTVAGHAjgvfHUoAg5FwKnC/dWpCigq4FRdr1cufHU9v+ixuHi563XqhY27Pvrh7Bn7ZUslAJLk+idf9FiSJEgA5M7nFx7jwmPpwoMLi9yPceF7dH7tXNZ5pRLpwn+lbl8vdV3W5bF02euv1B8AqKtrQKK5r/s5pCu/7rL1e/m673um7f2kHhq1v2dX9VWVABJ6ftF1xJAhuoQQAnYFsIlw1LUqsDkvPHcK2J2iy3Ob80KAKIDjQqh4EgCyBBhkwCBLCJNdj8MkCQYZCDcAEbKEMFmCQXYtM0iALEuur5JroyxLkuux9H1753P5oueGS55fHAqdYfHN/n0YMWJ4l1BxPQ7Oa/yVNtQg86YBvi7Dq0rPtnj1/RgyFDLsioDVJmB1CHRc+NfuFGi3C3Q4BdovWu7KiTQc/bbjsu8TJgPGMAmmMAlGAxBrkhFuAMIMEsINEsJlINwgIcyALs/dyzoD40JA+BODpCLc4F81UWBjyFBQEMIVHi0dAm12FVa7QJtNoM3uet5md+1lXEqWgIhwCZHhEqKNEhKiZUSGSTCFS6g+cxLDBt8EY5jUJVQMMjfCRFoxZChgCCHQ4QSa21W0dKg4bxOurx0qWmyu+YqLmcKAaKOMWJOMfrESok0Soo0yIi+ESmS4hHBD90NB7ZXnMaAP/0SIrgX/gsgvORSBpnYVjVYVTe0qmqwqGttV2Jzfv0aWgFiThNgIGZY4CXERMmIjZMQYJUQZJYRx2IfI5xgy5HNOVaCxTUVdm4q6NgX1bSrOd3w/ex4mA+ZIGal9wtAnUsYNkTLiIlxB4m9zGkTUFUOGvK7dIVDbouDseQXnWl17KJ3nIESGS0iMlnFzgow+UTLMUa49k2A9uoko2DFkSHc2p0B1s4KzLQpqzitovrCXEiYDiTEybu0fjsQYGYnRMqKMvNIRUTDxSsg0NjbixRdfxOnTp2E0GjFw4EDk5+cjPj4ew4YNw9ChQ933i16xYgWGDRsGANi5cydWrFgBRVFw66234tVXX0VkZORV28i3hBBotKqobFbwXZOCulYVAkC4DCTFGpCWKKNfnAEJUTJkHqlFFNS88rFRkiTMnTsXxcXF2Lp1K1JTU7Fq1Sp3+4YNG7BlyxZs2bLFHTBtbW14+eWXsWbNGnz66aeIjo7Ge++9d9U28g1VC
Jw9r2DPSRv+v33t2PbPDpR954AqgH9JDsfE9AjMyIzC2GERyEg2om+MgQFDFAK8EjJmsxmjR492Px8+fDiqqqp6XOfzzz9HRkYGBg0aBACYOXMm/vznP1+1jbxHCIGaFlewbNrXjk8OduBonRN9Y2TcfZMR00dE4aFbIzF8gBF9Yw2cpCcKQV6fk1FVFR999BFycnLcy+bMmQNFUXDPPfdg4cKFMBqNqK6uRnJysvs1ycnJqK6uBoAe20h/bXYVx+qcOHbOiRabgEEGUm4wYFB8GFLMBp4xTkRuXg+ZgoICREVFYfbs2QCA3bt3w2KxoLW1FYsXL8bq1avx85//3Cu1lJeX93rd0tLS61iJ/xMC2LX3EBpVM1oRDUBCNNqQIjUjFi0wnBeoPw/U+7rQ6yzU/j8DodfnUOsv4N0+ezVkCgsLcerUKaxZs8Y90W+xWAAAMTExmD59OtatW+devmfPHve6VVVV7tf21OaJjIwMmEwmj9crLS1FZmamx+sFIrsicOycE/tOt8IBI6KMEm5LDENaYhhiI6IBJPm6RN2E0v/nTqHW51DrL3BtfbbZbB5/OPfa8aKvv/46ysvLsXr1ahiNRgBAc3MzOjpcFyB0Op0oLi5Geno6ACA7OxsHDhzAyZMnAbgODpg4ceJV2+j6aLer2Hvahk1lVpSctiMcTtw72IRHbnfNscRG8FBjIro6r+zJHDlyBGvXrsWgQYMwc+ZMAMCAAQMwd+5c5ObmQpIkOJ1OjBgxAs8//zwA155Nfn4+5s+fD1VVkZ6ejiVLlly1ja5Nu11FebUDh885oarAwHgDbukfjlOHDmJgfF9fl0dEAcYrITNkyBAcOnToim1bt27tdr37778f999/v8dt5Dm7U+CbKgcO1bgOO745MQz/khyOuAt7LKd8XB8RBSae8R/iVFXg8Dkn9lfaYXMCNyeE4baU78OFiOhaMGRCWFWzE1+fsuN8h0D/OBlZqUbERxt8XRYRBRGGTAjqcAjsPW3D8XoFsSYJOUNMSDEbeBFKIrruGDIhRAiB4/VOlJyyw6kCtyWH41+Sw3mnRyLSDUMmRHQ4BL46acPpRgV9Y2TcOcgEcxTnXYhIXwyZEFDdrOBvx22wOQVGpobj1v7hHBojIq9gyAQxIQT2VzrwTZUDcREScoZGIIET+0TkRQyZIGV3CnxxzIbKZgVpiWEYPdDIe94TkdcxZIJQU7uKXYc70GoXGD3QiKFJYRweIyKfYMgEmZrzCnYd6YAsS3jg/0WgXyyHx4jIdxgyQeRkgxN/O2ZDrEnC2GERiDHx6DEi8i2GTJA4WOPA16fs6BsjI2doBExhHB4jIt9jyASBb886sPe0HalmA7IHmxDGkyuJyE8wZALct9UO7D1jx419DLgnzQSZAUNEfoQhE8AqzroCZmAfA7IZMETkhzgzHKCO1zlRctq1B8OAISJ/xZAJQFXNTnx5woZ+sTIDhoj8GkMmwNS3Kdh9xAZzpIz7hkTwCspE5NcYMgGk3SGw64gNxjAJY4eaYORhykTk5xgyAUJRBf56tAM2p8B9Q0yIMvJ/HRH5P26pAkTJaTtqW1TcdZOJV1ImooDBkAkAx+ucOFzrxK2WcNyUwKPOiShwMGT8XEuHij0nbUiKkTFiQLivyyEi8ghDxo8pqsDnx2yQJOAHaSbIvFw/EQUYhowf21/pQH2bax6GV1QmokDELZefqmtV8M9qBwb3DcON8ZyHIaLAxJDxQ4oq8PcTNkSGS8hKNfq6HCKiXmPI+KHyagea2gXGDDLyhEsiCmheCZnGxkbMmzcP48ePx+TJk/Hcc8+hoaEBALBv3z5MmTIF48ePx9NPP436+nr3er1tC2TN7SoOVDlwU4IBA/pwmIyIAptXQkaSJMydOxfFxcXYunUrUlNTsWrVKqiqisWLFyM3NxfFxcXIysrCqlWrAKDXbYFMCIGS03YYZCDrRpOvyyEiumZeCRmz2YzRo0e7nw8fPhxVVVUoLy+Hy
WRCVlYWAGDmzJnYsWMHAPS6LZB916SgqlnB8BQjIsM5TEZEgc/rczKqquKjjz5CTk4OqqurkZyc7G6Lj4+HqqpoamrqdVugUlTXXswNkRKGJXGYjIiCg0dbs/r6elit1i7LUlNTPXrDgoICREVFYfbs2fj00089Wvd6Ky8v7/W6paWl17ESoE6NR6tIwkD5NMrKrFdfwQeud58DAfsc/EKtv4B3+6wpZD7//HMsWbIEdXV1EEK4l0uShIqKCs1vVlhYiFOnTmHNmjWQZRkWiwVVVVXu9oaGBsiyDLPZ3Os2T2RkZMBk8nzuo7S0FJmZmR6v1x27U2DzfitSYgy4d1j6dfu+19P17nMgYJ+DX6j1F7i2PttsNo8/nGsaLsvPz8ezzz6LsrIyHDx40P3Pk4B5/fXXUV5ejtWrV8NodJ37kZGRgY6ODuzduxcAsGHDBkyYMOGa2gLRP6sdsCvgtcmIKOho2pM5f/48Zs6cCamX1846cuQI1q5di0GDBmHmzJkAgAEDBmD16tVYsWIF8vLyYLPZkJKSgpUrVwIAZFnuVVugaXcIVNQ4MCjegHhewp+IgoymkPnRj36ETZs24dFHH+3VmwwZMgSHDh26YtvIkSOxdevW69oWSA5U2aGowPABPLOfiIKPppDZv38/PvzwQ7z77rtITEzs0rZ+/XpdCgsF7Q6BI7VO3JwYhrgIXnyBiIKPppCZPn06pk+frnctIedgjQOKADIsnIshouCkKWQefvhhvesIOXZF4FCNAzf2MeCGSO7FEFFw0nyeTF1dHb755hs0NjZ2OYy5t/M0oe5wreuIMu7FEFEw0xQyn332GRYvXoyBAwfi6NGjGDx4MI4cOYKRI0cyZHpBVQUqzjrRP05GYgyPKCOi4KUpZN5880385je/wcSJEzFq1CgUFRVh06ZNOHr0qN71BaXTjQraHQJ33sQjyogouGmaDKiqqsLEiRO7LHv44YdRVFSkR01B71CtAzEmCck3cC+GiIKbppBJSEhAXV0dACAlJQVlZWU4ffo0VFXVtbhg1GhVUdOiYmhSGORentxKRBQoNIXM9OnT3RdUe/LJJ/H4449j6tSp+PGPf6xrccHoUK0DBgkYnMgJfyIKfprmZJ555hn342nTpuGOO+5Ae3s70tLSdCssGDkUgeN1TgxKCEME7xdDRCFA8yHMDocD+/fvR21tLR588EFYrVZYrVZERUXpWV9QOdXghFMFhvTl/WKIKDRo2todOnQICxYsgNFoRE1NDR588EGUlJTgf/7nf/Dmm2/qXGLwOFbnRKxJQt8YnnxJRKFB09Zu2bJl+OlPf4odO3YgLMyVS6NGjQrJm/30VqvNNeGflhjW66tZExEFGk0hc/ToUUydOhUA3BvIqKgo2Gw2/SoLMsfrnACAmxM5VEZEoUNTyKSkpFx2N7RvvvkGN954oy5FBRshBI7VOdEvVkaMiUNlRBQ6NH2sfv755zF//nzMnDkTDocDa9euxYYNG1BQUKB3fUGhwaqixSaQkczDlokotGj6WH3ffffhP//zP9HQ0IBRo0ahsrISv//97/GDH/xA7/qCwqkGBRKAVDOHyogotGje6t1yyy1YtmyZjqUEJyEETjW4LobJc2OIKNRoChmn04lt27ahoqICVqu1SxuHzHrW1C7QYhO4lZf0J6IQpClkFi9ejMOHD+Oee+5BQkKC3jUFlVMNTtdQWR8OlRFR6NG05fviiy+we/duxMTE6F1P0DnV4DqqLJJDZUQUgjRN/A8ePBjNzc161xJ0WjpUNHcIDOBeDBGFKE1bv5UrV2Lp0qW4++67kZiY2KVt2rRpetQVFCqbFQDAADPvG0NEoUlTyGzevBl79+5Fc3MzIiIi3MslSWLI9KCySUGsSUJcBE/AJKLQpClkPvjgAxQVFfHS/h5wqgJnzysYksShMiIKXZo+YicmJsJisehdS1CpOa9AEUAKb7FMRCFM08fsJ554AosXL8a8efMuO4Q5NTVVl
8IC3XdNCgwy0D+OIUNEoUtTyOTn5wMA/vKXv3RZLkkSKioqrn9VQeDseQX9Yw0wyDx0mYhCl6aQOXjw4DW/UWFhIYqLi1FZWYmtW7di6NChAICcnBwYjUaYTCYAwKJFi5CdnQ0A2LdvH3Jzc2Gz2ZCSkoKVK1e696R6avM1q9116HJaX074E1Fo89pWcOzYsVi/fj1SUlIua/vd736HLVu2YMuWLe6AUVUVixcvRm5uLoqLi5GVlYVVq1Zdtc0f1LSoAID+sRwqI6LQ5rWQycrK8ujggfLycphMJmRlZQEAZs6ciR07dly1zR+cPa8g3ADER3NPhohCm18cX7to0SIIIZCZmYkXXngBcXFxqK6uRnJysvs18fHxUFUVTU1NPbaZzWYf9KCrmhYF/WINkHmbZSIKcT4PmfXr18NiscBut+OVV15Bfn6+14a+Lr3bpydKS0uvuNwhwnBeHYxIWzVKSxt6/f39UXd9Dmbsc/ALtf4C3u2zz0OmcwjNaDRi1qxZWLBggXt5VVWV+3UNDQ2QZRlms7nHNk9kZGS4DzjwRGlpKTIzM6/YdrzOicPHbci65UYkRN/k8ff2Vz31OVixz8Ev1PoLXFufbTabxx/Ouw2Ze++9F5KG4Z7du3d79IYXs1qtUBQFsbGxEELg448/Rnp6OgBXAHR0dGDv3r3IysrChg0bMGHChKu2+dq5VgVhMtAnivMxRETdhszKlSvdjw8cOICioiLMmTMHycnJqKqqwh//+EePrlu2fPlyfPLJJ6irq8NTTz0Fs9mMNWvWYOHChVAUBaqqIi0tDXl5eQAAWZaxYsUK5OXldTlM+WptvlbXpiIxWuZ8DBERegiZO+64w/04Pz8f7733Hvr16+deds8992Du3Ll4+umnNb3R0qVLsXTp0suWFxUVdbvOyJEjsXXrVo/bfMWpCjRYVdzan3fBJCICNB7CXFtbi6ioqC7LoqKiUFNTo0tRgaqhTYUQQN8YDpUREQEaJ/5zcnKwYMECLFiwAP3790d1dTXWrl2LnJwcvesLKOdaXSdhJsbwJEwiIkBjyPz617/G73//e+Tl5aG2thZJSUmYMGECnnvuOb3rCyjnWhXEmCTeapmI6AJNIWMymbBo0SIsWrRI73oCWl2riqRYDpUREXXSfJ7Ml19+ie3bt6OhoQFr1qzBgQMH0NraijvvvFPP+gJGu0PA6hBIjOZQGRFRJ00fuz/88EMsW7YMgwYNQklJCQAgIiICv/3tb3UtLpA0WhUAvF4ZEdHFNG0R/+u//gvr1q3DM888A1l2rXLzzTfjxIkTuhYXSBraXJP+PAmTiOh7mraIbW1t7su/dF4FwOl0Ijyc54N0arCqiDZKMIVx0p+IqJOmkBk1ahTeeeedLss++OADjB49WpeiAlGjVUU892KIiLrQNPG/dOlS/Ou//is2btyItrY2jB8/HtHR0Vi7dq3e9QUEhyLQ3CEwMJ4hQ0R0MU0hk5SUhE2bNuHAgQOorKyExWLBbbfd5p6fCXVN7a75GE76ExF1pWmruGDBAkiShNtuuw0TJ07E8OHDIcsyT8a8oNHKSX8ioivRtFXcs2fPFZd//fXX17WYQNXUriJMBmKMnPQnIrpYj8NlnefBOByOy86JOXPmTJdbIIey5nYVN0TKmu6/Q0QUSnoMmbNnzwIAhBDux50sFgsWLlyoX2UBpLlDoH8sz/QnIrpUjyHz6quvAgBGjBiBxx57zCsFBRqHImC1C9wQyb0YIqJLaTq6rDNgWltb0djY2KUtNTX1+lcVQJovHFl2QwQn/YmILqUpZI4dO4Zf/OIXOHjwICRJghDCPf9QUVGha4H+rrnjQshEMmSIiC6lacu4bNkyjB49Gl9//TViYmJQUlKCGTNm4LXXXtO7Pr/X3C4gSUCsicNlRESX0hQyBw8exKJFixAXFwchBGJjY/Hiiy/yKsxwDZfFmSTIMkOGiOhSmkLGZDLB6XQCAPr06YOqqiqoqoqmpiY9a
wsI520q4jhURkR0RZq2jpmZmfjzn/8MABg/fjzmzZuHOXPmYMyYMboW5++EEGjtEBwqIyLqhqaJ/4uHxV544QUMHjwYVqsV06ZN06uugNDuEFAEEGPingwR0ZVovv1yJ1mWQz5cOrXaBABO+hMRdUdTyLS0tOCDDz5ARUUFrFZrl7b3339fl8ICQYvNdfhyLPdkiIiuSFPIPP/881AUBePGjYPJZNK7poDR0uHak4nmngwR0RVpCpl9+/bhq6++gtFo1LuegNJqE4gySjDw8GUioivSfHTZ8ePH9a4l4LTYVM7HEBH1QNOezGuvvYZ58+bh9ttvR0JCQpc2LTcuKywsRHFxMSorK7F161YMHToUAHDixAm89NJLaGpqgtlsRmFhIQYNGnRNbd7UahNIvoFXXyYi6o6mPZk33ngDZ8+eRX19PU6dOuX+d/r0aU1vMnbsWKxfvx4pKSldlufl5WHWrFkoLi7GrFmzkJube81t3qKoAu0OgRjuyRARdUvTnsz27dtRXFyMpKSkXr1JVlbWZcvq6+vx7bffYt26dQCASZMmoaCgAA0NDRBC9KotPj6+V/X1htV+YdKfd8MkIuqWppBJTU1FWJjHp9T0qLq6Gv369YPB4BpuMhgMSEpKQnV1NYQQvWrzRchEGXn4MhFRdzQlx9SpU/Hss89i9uzZl83J3HnnnboU5g3l5eW9X/fQcQDJOHnkW1RL9utXlB8rLS31dQlexz4Hv1DrL+DdPmsKmfXr1wMAXn/99S7LJUnCX/7yl169scViQU1NDRRFgcFggKIoqK2thcVigRCiV22eysjI6NV5P6Wlpeg3YCAqzzgwakQGwg3BP2RWWlqKzMxMX5fhVexz8Au1/gLX1mebzebxh3NNIbNz585eFdSThIQEpKenY9u2bZg6dSq2bduG9PR095BXb9u8xWoXCDcgJAKGiKi3ru9ESzeWL1+OTz75BHV1dXjqqadgNpuxfft2LFu2DC+99BLefvttxMXFobCw0L1Ob9u8xWp3nYhJRETd6zZkJk6c6L68/7333uu+3fKldu/efdU3Wbp0KZYuXXrZ8rS0NGzcuPGK6/S2zVusdoGocIYMEVFPug2ZgoIC9+OVK1d6pZhA0mbniZhERFfTbchcfG5LfX09Jk6ceNlrduzYoU9Vfk4IoMMheI4MEdFVaDrJY8mSJVdc7osz7f2BE2EQACI5XEZE1KMeJ/7PnDkDwHWb4c7HF7eF6lWZnXANk0UwZIiIetRjyIwbNw6SJEEIgXHjxnVpS0xMxMKFC3Utzl8pDBkiIk16DJmDBw8CAGbPno0//vGPXikoEDiF68cWGcaQISLqiaY5mUsD5syZM/juu+90KSgQcLiMiEgbTSHzwgsv4B//+AcAYNOmTXjooYcwadIkn5+r4isKwiBLQDiPYCYi6pGmkPm///s/ZGRkAAD+8Ic/YN26ddi4cSPeffddXYvzV04YEBEudXuCKhERuWi6rIzD4YDRaERNTQ2amprcF1erq6vTtTh/5RRhiOB8DBHRVWkKmfT0dKxduxaVlZX44Q9/CACoqalBTEyMnrX5LeXCngwREfVM03DZK6+8gsOHD8Nms+FnP/sZAKCsrAyTJ0/Wsza/5QT3ZIiItNC0J3PjjTfi3//937ssmzBhAiZMmKBLUf7ONSfj6yqIiPxfj3syy5cv7/L80qPJQvFkTEUVEJBh4p4MEdFV9Rgymzdv7vL80qsxf/nll9e/Ij9ncwoAgJE3KyMiuqoeQ0YI0ePzUGRXXF+N3JMhIrqqHkPm0vNAeF4IYHfvyfi4ECKiANDjxL+iKPjqq6/cezBOp7PLc1VV9a/Qz9iVCyHDPRkioqvqMWQSEhLwq1/9yv3cbDZ3eR4fH69fZX7K7nR9NXFOhojoqnoMmZ07d3qrjoDhnvjnngwR0VVpOhmTvudQOCdDRKQVQ8ZDNqeADBWyzD0ZIqKrYch4yK4AMhRfl0FEFBAYMh6yOwUMDBkiI
k0YMh5yqoAMnpRKRKQFQ8ZjDBgiIq0YMr3CoCEi0oIhQ0REutF0Pxm95eTkwGg0wmQyAQAWLVqE7Oxs7Nu3D7m5ubDZbEhJScHKlSuRkJAAAD226Yn7MERE2vnNnszvfvc7bNmyBVu2bEF2djZUVcXixYuRm5uL4uJiZGVlYdWqVQDQY5s38AwZIiJt/CZkLlVeXg6TyYSsrCwAwMyZM7Fjx46rthERkf/wi+EywDVEJoRAZmYmXnjhBVRXVyM5OdndHh8fD1VV0dTU1GOb2WzWt1COlxERaeYXIbN+/XpYLBbY7Xa88soryM/Px7hx43R/3/Lyco/XaVFSAUgoLS29/gX5OfY5NIRan0Otv4B3++wXIWOxWAAARqMRs2bNwoIFC/D444+jqqrK/ZqGhgbIsgyz2QyLxdJtmycyMjLcBxtoVVfRjpaWVmRmZnq0XqArLS1ln0NAqPU51PoLXFufbTabxx/OfT4nY7Va0dLSAsB1e+ePP/4Y6enpyMjIQEdHB/bu3QsA2LBhAyZMmAAAPbYREZH/8PmeTH19PRYuXAhFUaCqKtLS0pCXlwdZlrFixQrk5eV1OUwZQI9tRETkP3weMqmpqSgqKrpi28iRI7F161aP24iIyD/4fLgs0IiL/ktERD1jyPQCT8YkItKGIeMp7sQQEWnGkPEQM4aISDuGDBER6YYhQ0REumHIEBGRbhgyRESkG4aMhwRn/omINGPI9AqThohIC4YMERHphiHTCzzjn4hIG4YMERHphiHjIc7GEBFpx5AhIiLdMGSIiEg3DBlPcbyMiEgzhoyHmDFERNoxZIiISDcMGSIi0g1DxkMCgMRBMyIiTRgyRESkG4YMERHphiHjKY6UERFpxpAhIiLdMGQ8JLgrQ0SkGUOGiIh0w5AhIiLdBHTInDhxAjNmzMD48eMxY8YMnDx50tclERHRRQI6ZPLy8jBr1iwUFxdj1qxZyM3N1f09eTImEZF2ARsy9fX1+PbbbzFp0iQAwKRJk/Dtt9+ioaHBx5UREVGnMF8X0FvV1dXo168fDAYDAMBgMCApKQnV1dWIj4/X9D3Ky8s9fl+nkgqT5ERpaanH6wY69jk0hFqfQ62/gHf7HLAhcz1kZGTAZDJ5tM7tqkDZP/6BzMxMnaryT6WlpexzCAi1Podaf4Fr67PNZvP4w3nADpdZLBbU1NRAURQAgKIoqK2thcVi0fV9w2QJsqTrWxARBY2ADZmEhASkp6dj27ZtAIBt27YhPT1d81AZERHpL6CHy5YtW4aXXnoJb7/9NuLi4lBYWOjrkoiI6CIBHTJpaWnYuHGjr8sgIqJuBOxwGRER+T+GDBER6YYhQ0REugnoOZneEsJ1WRi73d7r72Gz2a5XOQGDfQ4NodbnUOsv0Ps+d24zO7ehWkjCk1cHiZaWFhw+fNjXZRARBaShQ4ciNjZW02tDMmRUVUVbWxvCw8MhSTyzkohICyEEHA4HoqOjIcvaZltCMmSIiMg7OPFPRES6YcgQEZFuGDJERKQbhgwREemGIUNERLphyBARkW4YMkREpBuGjAdOnDiBGTNmYPz48ZgxYwZOnjzp65I0aWxsxLx58zB+/HhMnjwZzz33HBoaGgAA+/btw5QpUzB+/Hg8/fTTqK+vd6+nR5svvPXWWxg2bJj7Kg/B3GebzYa8vDw88MADmDx5Ml5++WUAPf/u6tHmTbt27cK0adMwdepUTJkyBZ988sk11e5vfS4sLEROTk6X32Ff9K/XfRek2Zw5c0RRUZEQQoiioiIxZ84cH1ekTWNjo/jqq6/cz1977TXxy1/+UiiKIu6//35RUlIihBBi9erV4qWXXhJCCF3afKG8vFz85Cc/Effdd584dOhQ0Pe5oKBAvPLKK0JVVSGEEOfOnRNC9Py7q0ebt6iqKrKyssShQ4eEEEJUVFSI4cOHC0VRgqbPJSUloqqqyv07rGcf9Og7Q0ajuro6kZmZKZxOpxBCCKfTKTIzM0V9fb2PK
/Pcjh07xBNPPCH2798vHnroIffy+vp6MXz4cCGE0KXN22w2m3jsscfEmTNn3H+gwdzn1tZWkZmZKVpbW7ss7+l3V482b1JVVdxxxx1i7969Qgghvv76a/HAAw8EZZ8vDhlv9+9a+h6SV2HujerqavTr1w8GgwEAYDAYkJSUhOrqasTHx/u4Ou1UVcVHH32EnJwcVFdXIzk52d0WHx8PVVXR1NSkS5vZbPZKHzv99re/xZQpUzBgwAD3smDu85kzZ2A2m/HWW29hz549iI6OxvPPP4+IiIhuf3eFENe9zZt/D5Ik4c0338Szzz6LqKgotLW14Z133unx7zXQ+wz0vD3So3/X0nfOyYSYgoICREVFYfbs2b4uRVdlZWUoLy/HrFmzfF2K1yiKgjNnzuCWW27B5s2bsWjRIixcuBBWq9XXpenG6XRi7dq1ePvtt7Fr1y78x3/8B372s58FdZ8DDfdkNLJYLKipqYGiKDAYDFAUBbW1tbBYLL4uTbPCwkKcOnUKa9asgSzLsFgsqKqqcrc3NDRAlmWYzWZd2ryppKQEx44dw9ixYwEAZ8+exU9+8hPMmTMnaPtssVgQFhaGSZMmAQBuv/129OnTBxEREd3+7gohrnubN1VUVKC2thaZmZkAgMzMTERGRsJkMgVtn4Get0d69O9a+s49GY0SEhKQnp6Obdu2AQC2bduG9PT0gBkqe/3111FeXo7Vq1fDaDQCADIyMtDR0YG9e/cCADZs2IAJEybo1uZNzzzzDP72t79h586d2LlzJ/r374/33nsPc+fODdo+x8fHY/To0fjyyy8BuI4Gqq+vx6BBg7r93e3p97q3bd7Uv39/nD17FsePHwcAHDt2DPX19Rg4cGDQ9hnoeXvk7baruqaZqBBz9OhR8eijj4oHHnhAPProo+LYsWO+LkmTw4cPi6FDh4oHHnhATJkyRUyZMkU8++yzQgghSktLxaRJk8S4cePEk08+6T4aSa82X7l40jSY+3z69Gkxe/ZsMWnSJDFt2jSxe/duIUTPv7t6tHnTli1bxKRJk8TkyZPF5MmTxaeffnpNtftbnwsKCkR2drZIT08Xd911l3jwwQd90r/e9p33kyEiIt1wuIyIiHTDkCEiIt0wZIiISDcMGSIi0g1DhoiIdMOQIfJzI0aMwJkzZ3xdBlGvMGSIriInJwd///vfsXnzZvz4xz/W9b3mzJmDjRs3dllWVlaG1NRUXd+XSC8MGSIvcTqdvi6ByOsYMkQaHDt2DHl5edi3bx9GjBiBrKwsAIDdbkdhYSF++MMf4q677kJubi46OjoAAHv27ME999yDd955B3fffTd++ctform5GfPnz8eYMWMwatQozJ8/H2fPngUAvPHGG9i7dy/y8/MxYsQI5OfnAwCGDRuGU6dOAQBaWlrw4osvYsyYMbjvvvvw9ttvQ1VVAHDvaRUWFmLUqFHIycnBX//6V3cfNm/ejLFjx2LEiBHIycnB//7v/3rt50ehiyFDpEFaWhp+/etfY/jw4SgrK3Nfp2zVqlU4ceIEioqK8Mknn6C2tharV692r1dXV4fm5mbs2rULBQUFUFUVjzzyCHbt2oVdu3bBZDK5w+TnP/85srKykJubi7KyMuTm5l5WR0FBAVpaWvDZZ5/hww8/xJYtW7Bp0yZ3+zfffIObbroJX331FebOnYslS5ZACAGr1Yrly5fj3XffRVlZGTZs2ID09HSdf2pEDBmiXhNC4E9/+hN+9atfwWw2IyYmBvPnz8f27dvdr5FlGT/96U9hNBoRERGBPn36YPz48YiMjERMTAwWLFiAkpISTe+nKAo+/vhj/OIXv0BMTAwGDBiAp556qsseSXJyMh577DEYDAY8/PDDOHfuHOrq6ty1HDlyBB0dHUhKSsKQIUOu7w+E6Ap4qX+iXmpoaEB7ezseeeQR9zIhhHv4CgD69OkDk8nkft7e3o5XX30VX3zxBZqbmwEAbW1t7kuo96SxsREOh6PLDdKSk
5NRU1Pjfp6YmOh+HBkZCQCwWq3o27cv3njjDbz//vtYsmQJRo4ciX/7t39DWlpaL3tPpA1DhkgjSZK6PO+8V8v27dvRr18/Teu8//77OHHiBP70pz+hb9++qKiowLRp06DlOrV9+vRBeHg4qqqqMHjwYADf3yFRi+zsbGRnZ6OjowNvvvkmXn75Zfz3f/+3pnWJeovDZUQaJSQkoKamBna7HYBr+Gn69On4zW9+g/r6egBATU0Nvvjii26/R1tbG0wmE+Li4tDU1IS33nqrS3tiYmK358QYDAZMmDABb7zxBlpbW1FZWYl169ZhypQpV629rq4On332GaxWK4xGI6KioiDL/PMn/fG3jEijMWPGYPDgwfjBD36A0aNHAwAWL16MgQMH4rHHHsPIkSPx5JNP4sSJE91+jyeeeAI2mw1jxozBjBkzkJ2d3aX98ccfR3FxMUaNGoXly5dftv7LL7+MyMhI3H///Zg1axYmTZqEH/3oR1etXVVV/OEPf0B2djbuuOMOlJSUYNmyZZ79AIh6gfeTISIi3XBPhoiIdMOQISIi3TBkiIhINwwZIiLSDUOGiIh0w5AhIiLdMGSIiEg3DBkiItINQ4aIiHTz/wPmzTFiRR7HQAAAAABJRU5ErkJggg==\n" }, "metadata": {} } ] }, { "cell_type": "markdown", "metadata": { "id": "XINkmCdVeXPr" }, "source": [ "## Maximum Likelihood II\n", "\n", "A colleague of yours suggest that the parameter $\\lambda$ must be itself dependent on the weather and other factors since people bike when its not raining. Assume that you model $\\lambda$ as \n", "\n", "$$\\lambda_i = \\exp(\\mathbf w^T \\mathbf x_i)$$\n", "\n", "where $\\mathbf x_i$ is one of the example features and $\\mathbf w$ is a set of parameters. \n", "\n", "Train the model with SGD with this assumption and compare the MSE of the predictions with the `Maximum Likelihood I` approach. 
\n", "\n", "You may want to use [this partial derivative of the log likelihood function](http://home.cc.umanitoba.ca/~godwinrt/7010/poissonregression.pdf)." ] }, { "cell_type": "markdown", "source": [ "${\\partial l \\over \\partial \\beta} = \\sum_{i=1}^n \\left(y_i - \\exp(X'_i\\beta)\\right) X_i$" ], "metadata": { "id": "ni9r4QxeK4L3" } }, { "cell_type": "code", "source": [ "FEATURES = ['HIGH_T', 'LOW_T', 'PRECIP']\n", "\n", "def design_matrix(frame):\n", "    # Standardize the features with full-dataset statistics so that\n", "    # exp(x @ w) stays in a numerically safe range, then prepend an intercept.\n", "    z = (frame[FEATURES] - df[FEATURES].mean()) / df[FEATURES].std()\n", "    return np.hstack([np.ones((len(frame), 1)), z.to_numpy()])\n", "\n", "def s_gradient(x, y, w):\n", "    # Gradient of the Poisson log likelihood, averaged over the batch:\n", "    # (1/n) * sum_i (y_i - exp(x_i . w)) * x_i\n", "    y_est = np.exp(np.dot(x, w))\n", "    return np.dot(x.T, y - y_est) / len(y)\n", "\n", "def s_gradient_descent(learning_rate, max_iter, batch_size=0.2):\n", "    # Start the intercept at log(mean count); the other weights start at 0.\n", "    w = np.zeros(len(FEATURES) + 1)\n", "    w[0] = np.log(df['BB_COUNT'].mean())\n", "\n", "    for _ in range(max_iter):\n", "        samples = df.sample(frac=batch_size)\n", "        x = design_matrix(samples)\n", "        y = samples['BB_COUNT'].to_numpy()\n", "        g = s_gradient(x, y, w)\n", "        # g is the gradient of the log likelihood (not its negative),\n", "        # so the update moves uphill.\n", "        w = w + learning_rate * g\n", "\n", "    return w" ], "metadata": { "id": "9e8m9A9jsrO2" }, "execution_count": 157, "outputs": [] }, { "cell_type": "code", "source": [ "# Starting values; the step size and iteration count may need tuning.\n", "iterations = 2000\n", "alpha = 0.0001\n", "\n", "w = s_gradient_descent(alpha, iterations)\n", "\n", "x = design_matrix(df)\n", "y = df['BB_COUNT'].to_numpy()\n", "# Matrix-vector product: one predicted lambda_i per day, shape (n,)\n", "l = np.exp(np.dot(x, w))\n", "\n", "# Maximum Likelihood I predicts the sample mean for every day.\n", "mse_regression = np.mean((y - l) ** 2)\n", "mse_constant = np.mean((y - y.mean()) ** 2)\n", "print(f\"MSE (Maximum Likelihood II)={mse_regression}\")\n", "print(f\"MSE (Maximum Likelihood I)={mse_constant}\")\n" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 223 }, "id": "FMOiaAn7A86p", "outputId": "f07affe7-c717-424e-c5ed-75c9f388ca00" }, "execution_count": 167, "outputs": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10.9" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "7d6993cb2f9ce9a59d5d7380609d9cb5192a9dedd2735a011418ad9e827eb538" } }, "colab": { "provenance": [] } }, "nbformat": 4, "nbformat_minor": 0 }