{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "UhZuaTM3YjJ-" }, "source": [ "# Midterm - Spring 2023\n", "\n", "## Problem 1: Take-at-home (45 points total)\n", "\n", "You are applying for a position on the data science team of the USDA and you are given data associated with determining the appropriate parasite treatment for canines. The suggested treatment options are determined by a **logistic regression** model that predicts whether the canine is infected with a parasite. \n", "\n", "The data is available at https://data.world/ehales/grls-parasite-study/workspace/file?filename=CBC_data.csv, specifically in the CBC_data.csv file. Log in with your University Google account to access the data and its description, which includes a paper on the study (**you don't need to read the paper to solve this problem**). Your target variable $y$ is the column titled `parasite_status`. \n", "\n", "\n" ] }, { "cell_type": "markdown", "source": [ "- https://pantelis.github.io/artificial-intelligence/intro.html" ], "metadata": { "id": "AWSv8yOxyXYD" } }, { "cell_type": "markdown", "metadata": { "id": "1THcWuqiYjJ_" }, "source": [ "### Question 1 - Feature Engineering (5 points)\n", "\n", "Write the posterior probability expressions for logistic regression for the problem you are given to solve." ] }, { "cell_type": "markdown", "metadata": { "id": "MckwhLbUYjJ_" }, "source": [ "$$p(y=1| \\mathbf{x}, \\mathbf w)$$ \n", "\n", "$$p(y=0| \\mathbf{x}, \\mathbf w)$$ " ] }, { "cell_type": "markdown", "source": [ "$$p(y = 1|\\mathbf{x}, \\mathbf{w}) = \\sigma (\\mathbf{x}^T\\mathbf{w}) = {1 \\over 1 + e^{-\\mathbf{x}^T\\mathbf{w}}}$$\n", "\n", "$$p(y = 0|\\mathbf{x}, \\mathbf{w}) = 1 - \\sigma (\\mathbf{x}^T\\mathbf{w}) = {e^{-\\mathbf{x}^T\\mathbf{w}} \\over 1 + e^{-\\mathbf{x}^T\\mathbf{w}}}$$\n", "\n" ], "metadata": { "id": "Dof11_sUofVi" } }, { "cell_type": "markdown", "metadata": { "id": "_cHO1w6HYjJ_" }, "source": [ "\n", "### Question 2 - Decision Boundary (5 points)\n", "\n", "Write the expression for the decision boundary assuming that $p(y=1)=p(y=0)$. 
The decision boundary is the line that separates the two classes." ] }, { "cell_type": "markdown", "metadata": { "id": "vKseaYyfYjKA" }, "source": [ "$$p(y=1) + p(y=0) = 1 \\quad \\text{and} \\quad p(y=1) = p(y=0) \\;\\Rightarrow\\; p(y=1|\\mathbf{x}, \\mathbf{w}) = \\sigma(\\mathbf{w}^T\\mathbf{x} + \\alpha) = 0.5$$\n", "\n", "Linear decision function (with bias $\\alpha$):\n", "\n", "$$f(\\mathbf{x}) = \\mathbf{w}^T\\mathbf{x} + \\alpha$$\n", "\n", "Since $\\sigma(a) = 0.5$ exactly when $a = 0$, the decision boundary is the hyperplane $f(\\mathbf{x}) = 0$:\n", "\n", "$$H = \\{\\mathbf{x} : \\mathbf{w}^T\\mathbf{x} = - \\alpha\\}$$\n", "\n", "On $H$ the predicted probability is $\\hat y = 0.5$." ] }, { "cell_type": "markdown", "metadata": { "id": "nGDtm1LWYjKA" }, "source": [ "\n", "\n", "### Question 3 - Loss function (5 points)\n", "\n", "Write the expression of the loss as a function of $\\mathbf w$ that makes sense for you to use in this problem. \n", "\n", "NOTE: The loss will be a function that will include this function: \n", "\n", "$$\\sigma(a) = \\frac{1}{1+e^{-a}}$$\n", "\n" ] }, { "cell_type": "markdown", "metadata": { "id": "ytoLWwasYjKA" }, "source": [ "$$a_i = \\mathbf{w}^T\\mathbf{x}_i$$\n", "\n", "$$L_{CE}(\\mathbf{w}) = - \\sum_{i=1}^n \\left[ y_i \\ln \\sigma(a_i) + (1 - y_i) \\ln\\left(1 - \\sigma(a_i)\\right) \\right]$$" ] }, { "cell_type": "markdown", "metadata": { "id": "M_0ufZQtYjKA" }, "source": [ "\n", "### Question 4 - Gradient (5 points)\n", "\n", "Write the expression of the gradient of the loss with respect to the parameters - show all your work.\n", "\n" ] }, { "cell_type": "markdown", "source": [ "$${d \\over da} \\sigma (a) = {d \\over da} \\left(\\frac{1}{1+e^{-a}}\\right)$$\n", "\n", "$$= {e^{-a} \\over (1 + e^{-a})^2}$$\n", "\n", "$$= {1 \\over 1 + e^{-a}} \\cdot {e^{-a} \\over 1 + e^{-a}}$$\n", "\n", "$$= {1 \\over 1 + e^{-a}} \\cdot {1 + e^{-a} - 1 \\over 1 + e^{-a}}$$\n", "\n", "$$= {1 \\over 1 + e^{-a}} \\cdot \\left({1 + e^{-a} \\over 1 + e^{-a}} - {1 \\over 1 + e^{-a}}\\right)$$\n", "\n", "$$= \\sigma(a)(1 - \\sigma(a))$$\n", "\n" ], "metadata": { "id": "j7JwYBU5c2Oz" } }, { "cell_type": "markdown", "metadata": { "id": "mM9vu8WnYjKA" }, "source": [ "Using the chain rule with $\\hat y_i = \\sigma(a_i)$ and $\\partial a_i / \\partial \\mathbf w = \\mathbf x_i$:\n", "\n", "$$\\nabla_\\mathbf w L_{CE} = - \\sum_{i=1}^n \\left[ {y_i \\over \\sigma(a_i)} - {1 - y_i \\over 1 - \\sigma(a_i)} \\right] \\sigma(a_i)(1 - \\sigma(a_i))\\, \\mathbf x_i$$\n", "\n", "$$= - \\sum_{i=1}^n \\left[ y_i (1 - \\sigma(a_i)) - (1 - y_i)\\, \\sigma(a_i) \\right] \\mathbf x_i$$\n", "\n", "$$ \\nabla_\\mathbf w L_{CE} = \\sum_{i=1}^n (\\hat y_i - y_i)\\mathbf x_i$$\n" ] }, { "cell_type": "markdown", "metadata": { "id": "mbKlMmtMYjKB" }, "source": [ "### Question 5 - Imbalanced dataset (10 points)\n", "\n", "You are 
now told that in the dataset \n", "\n", "$$p(y=0) \\gg p(y=1)$$\n", "\n", "Can you comment on whether the accuracy of Logistic Regression will be affected by such an imbalance?\n", "\n" ] }, { "cell_type": "code", "source": [ "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "\n", "np.random.seed(0)\n", "sns.set_theme(style='whitegrid', palette='pastel')\n", "\n", "# Silence all warnings (this also hides pandas deprecation and scikit-learn fit-failure warnings)\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "\n", "# CBC_data.csv must be downloaded manually from data.world into the working directory\n", "\n", "df = pd.read_csv('CBC_data.csv')\n", "df.info()" ], "metadata": { "id": "c6XkzsjVR5cx", "colab": { "base_uri": "https://localhost:8080/" }, "outputId": "8535c7b6-c17e-42d3-ad3a-cfcad4e34f25" }, "execution_count": 537, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "<class 'pandas.core.frame.DataFrame'>\n", "RangeIndex: 3018 entries, 0 to 3017\n", "Data columns (total 15 columns):\n", " # Column Non-Null Count Dtype \n", "--- ------ -------------- ----- \n", " 0 ID 3018 non-null object \n", " 1 SEX 3018 non-null object \n", " 2 TYPEAREA 3018 non-null object \n", " 3 SEX.REPRO 3018 non-null object \n", " 4 REPRO.STATUS 3018 non-null object \n", " 5 AGE 3018 non-null int64 \n", " 6 PARASITE_STATUS 3018 non-null object \n", " 7 RBC 2995 non-null float64\n", " 8 HGB 2995 non-null float64\n", " 9 WBC 2996 non-null float64\n", " 10 EOS.CNT 2995 non-null float64\n", " 11 MONO.CNT 2995 non-null float64\n", " 12 NUT.CNT 2995 non-null float64\n", " 13 PL.CNT 2995 non-null float64\n", " 14 LYMP.CNT 2995 non-null float64\n", "dtypes: float64(8), int64(1), object(6)\n", "memory usage: 353.8+ KB\n" ] } ] }, { "cell_type": "code", "source": [ "# Label each pie wedge with its sample count and percentage\n", "def label_function(val):\n", " return f'{val / 100 * len(df):.0f}\\n{val:.0f}%'\n", "\n", "df.groupby('PARASITE_STATUS').size().plot(kind='pie', autopct=label_function)\n", "\n", "plt.ylabel('') \n", "plt.title('Parasite Status')\n", "plt.show()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 264 }, "id": 
"_ykAJfjMSEnm", "outputId": "77e9250b-b988-4cfd-b733-9b0b88321622" }, "execution_count": 538, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "\n" }, "metadata": {} } ] }, { "cell_type": "markdown", "metadata": { "id": "m0ME9LZGYjKB" }, "source": [ "This is a binary classification problem on a dataset that is mostly negative: based on the pie chart above, about 93% of the samples are labeled 0 (negative) and 7% are labeled 1 (positive).\n", "\n", "The imbalance does affect logistic regression. The cross-entropy loss is a sum over all samples, so it is dominated by the negative class, and minimizing it pushes the weights (in particular the bias) toward predicting 0. The trained model therefore tends to produce more false negatives than false positives, i.e. it misses infected dogs.\n", "\n", "Accuracy itself becomes a misleading metric: a trivial model that labels every sample negative already reaches about 93% accuracy while detecting no infected dogs at all. Precision, recall and the $F_1$ score of the positive class are more informative.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "6hzmRuBNYjKB" }, "source": [ "\n", "### Question 6 - SGD (15 points)\n", "\n", "The interviewer was impressed with your answers and wants to test your programming skills. \n", "\n", "1. Use the dataset to train a logistic regressor that will predict the target variable $y$. \n", "\n", "2. Report the harmonic mean of precision (p) and recall (r), i.e., the [metric called $F_1$ score](https://en.wikipedia.org/wiki/F-score), calculated as shown below on a test dataset that contains 20% of each group. Plot the $F_1$ score vs. the iteration number $t$. \n", "\n", "$$F_1 = \\frac{2}{r^{-1} + p^{-1}}$$\n", "\n", "Your code should include hyperparameter optimization of the learning rate and mini-batch size. 
Please learn about cross-validation, a splitting strategy for tuning models, [here](https://scikit-learn.org/stable/modules/cross_validation.html).\n", "\n", "You are allowed to use any library you want to code this problem.\n", "\n" ] }, { "cell_type": "code", "source": [ "from sklearn.metrics import f1_score\n", "from sklearn.model_selection import GridSearchCV\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.linear_model import LogisticRegression, SGDClassifier" ], "metadata": { "id": "d47OQOdkkAJN" }, "execution_count": 539, "outputs": [] }, { "cell_type": "code", "source": [ "df.head()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 270 }, "id": "6lCFhxjSX04P", "outputId": "e41f7d82-64c2-4d99-daaf-b7a01c163645" }, "execution_count": 540, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " ID SEX TYPEAREA SEX.REPRO REPRO.STATUS AGE \\\n", "0 grls5ZUT2BYY Male Suburban IntactMale Intact 9 \n", "1 grls8DCONYUU Female Rural NeuteredFemale Neutered 6 \n", "2 grlsUC5R4PTT Male Suburban IntactMale Intact 14 \n", "3 grlsXUR2PY88 Male Rural IntactMale Intact 6 \n", "4 grlsTBZUF3GG Female Rural IntactFemale Intact 18 \n", "\n", " PARASITE_STATUS RBC HGB WBC EOS.CNT MONO.CNT NUT.CNT PL.CNT \\\n", "0 Negative 6.4 16.6 14.2 142.0 852.0 6390.0 210.0 \n", "1 Negative 4.8 12.5 10.0 400.0 300.0 4800.0 209.0 \n", "2 Negative 6.2 17.3 9.5 190.0 475.0 7315.0 164.0 \n", "3 Negative 5.4 13.8 14.1 1692.0 423.0 7755.0 254.0 \n", "4 Negative 5.9 14.4 6.5 390.0 130.0 2795.0 213.0 \n", "\n", " LYMP.CNT \n", "0 6816.0 \n", "1 4500.0 \n", "2 1520.0 \n", "3 4230.0 \n", "4 3185.0 " ], "text/html": [ "\n", "
\n", " " ] }, "metadata": {}, "execution_count": 540 } ] }, { "cell_type": "markdown", "source": [ "**Preprocessing**" ], "metadata": { "id": "Akl94gXapAzB" } }, { "cell_type": "code", "source": [ "# Remove rows with NaNs\n", "df = df.dropna()\n", "\n", "# The columns encoded below are either binary or have a natural order,\n", "# so they are mapped to integers manually.\n", "#\n", "# Ex: Rural -> Suburban -> Urban has increasing population density, so it is mapped\n", "# to [0, 1, 2]\n", "#\n", "# Assigning the result back instead of using inplace=True on a column avoids chained assignment\n", "\n", "df['SEX'] = df['SEX'].replace({'Male': 1, 'Female': 0})\n", "df['REPRO.STATUS'] = df['REPRO.STATUS'].replace({'Intact': 1, 'Neutered': 0})\n", "df['PARASITE_STATUS'] = df['PARASITE_STATUS'].replace({'Negative': 0, 'Positive': 1})\n", "df['TYPEAREA'] = df['TYPEAREA'].replace({'Rural': 0, 'Suburban': 1, 'Urban': 2})\n", "\n", "\n", "# Undersampling: keep every positive sample and an equally sized random sample of negatives\n", "# https://www.datasnips.com/63/undersampling-imbalanced-data-for-binary-classification/\n", "# Because this runs before the train/test split, the test set ends up balanced as well\n", "positive = df[df['PARASITE_STATUS'] == 1]\n", "negative = df[df['PARASITE_STATUS'] == 0]\n", "negative = negative.sample(n=len(positive), random_state=42)\n", "df = pd.concat([positive, negative], axis=0)\n", "\n", "\n", "# Drop `ID` (no model-relevant information) and `SEX.REPRO` (redundant with the\n", "# existing `SEX` and `REPRO.STATUS` columns); errors='ignore' keeps the cell re-runnable\n", "df = df.drop(columns=['ID', 'SEX.REPRO'], errors='ignore')\n", "\n", "\n", "# Shifting target variable to the front for readability\n", "cols = ['PARASITE_STATUS'] + [col for col in df if col != 'PARASITE_STATUS']\n", "df = df[cols]\n", "\n", "df.head()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 250 }, "id": "ARJrbriigHii", "outputId": "7d1d7b07-9cf4-41dc-e177-1b1d699babb3" }, "execution_count": 541, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ " PARASITE_STATUS SEX TYPEAREA REPRO.STATUS AGE RBC HGB WBC \\\n", "7 1 1 1 1 9 5.8 14.7 13.9 \n", "19 1 1 0 1 25 5.8 14.6 11.3 \n", "23 1 1 1 1 24 5.7 14.4 10.1 \n", "24 1 0 1 1 11 5.0 13.6 10.7 \n", "52 1 1 1 1 7 5.6 14.4 11.8 \n", "\n", " EOS.CNT MONO.CNT NUT.CNT PL.CNT LYMP.CNT \n", "7 139.0 417.0 7089.0 334.0 6255.0 \n", "19 0.0 1017.0 6667.0 183.0 3616.0 \n", "23 3131.0 404.0 3333.0 262.0 3232.0 \n", "24 1177.0 535.0 4922.0 318.0 4066.0 \n", "52 118.0 354.0 5664.0 319.0 5664.0 " ], "text/html": [ "\n", "
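\n", " " ] }, "metadata": {}, "execution_count": 541 } ] }, { "cell_type": "markdown", "source": [ "**Modeling**" ], "metadata": { "id": "eJBjumyMpMl-" } }, { "cell_type": "code", "source": [ "X = df.drop(['PARASITE_STATUS'], axis=1)\n", "y = df['PARASITE_STATUS']\n", "\n", "# stratify=y puts 20% of each class (group) into the test set\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y, random_state=42)" ], "metadata": { "id": "NVyvAkQvjiir" }, "execution_count": 542, "outputs": [] }, { "cell_type": "code", "source": [ "from sklearn.preprocessing import StandardScaler\n", "\n", "# Standardize the features using training-set statistics only; SGD is sensitive\n", "# to feature scale and the cell counts are in the thousands\n", "scaler = StandardScaler().fit(X_train)\n", "X_train_s = scaler.transform(X_train)\n", "X_test_s = scaler.transform(X_test)\n", "\n", "iters = 1000\n", "\n", "# loss='log_loss' makes SGDClassifier a logistic regressor (the default loss='hinge'\n", "# is a linear SVM); scikit-learn < 1.1 spells it loss='log'\n", "clf = SGDClassifier(loss='log_loss', shuffle=True, random_state=42)\n", "\n", "f1_scores = []\n", "\n", "# Each partial_fit call is one pass over the training set; the test F1 is recorded after each\n", "for _ in range(iters):\n", "    clf.partial_fit(X_train_s, y_train, classes=[0, 1])\n", "    y_pred = clf.predict(X_test_s)\n", "    f1_scores.append(f1_score(y_test, y_pred))\n", "\n", "\n", "sns.lineplot(x=range(len(f1_scores)), y=f1_scores)\n", "plt.xlabel('Iteration $t$')\n", "plt.ylabel('$F_1$ score')\n", "plt.show()" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 270 }, "id": "n-_mpDTs6em-", "outputId": "4caba560-a959-4dcc-9cd4-be11264137d8" }, "execution_count": 543, "outputs": [ { "output_type": "display_data", "data": { "text/plain": [ "
" ], "image/png": "\n" }, "metadata": {} } ] }, { "cell_type": "markdown", "source": [ "**Tuning**" ], "metadata": { "id": "QuBneAp4Gck5" } }, { "cell_type": "code", "source": [ "from sklearn.pipeline import make_pipeline\n", "from sklearn.preprocessing import StandardScaler\n", "\n", "# SGDClassifier updates the weights one sample at a time and has no mini-batch size\n", "# parameter, so only the learning-rate schedule and the initial rate eta0 are tuned here.\n", "# eta0 must be > 0 for 'constant', 'invscaling' and 'adaptive' ('optimal' ignores it).\n", "param_grid = {\n", "    'sgdclassifier__learning_rate': ['constant', 'optimal', 'invscaling', 'adaptive'],\n", "    'sgdclassifier__eta0': [0.001, 0.01, 0.1],\n", "}\n", "\n", "# The scaler inside the pipeline is fit on the training part of each CV fold only;\n", "# loss='log_loss' makes the SGD model a logistic regressor\n", "sgd = make_pipeline(StandardScaler(), SGDClassifier(loss='log_loss', random_state=42))\n", "\n", "# scoring='f1' is the binary F1 score of the positive class\n", "grid_search = GridSearchCV(sgd, param_grid, cv=5, scoring='f1')\n", "grid_search.fit(X_train, y_train)\n", "\n", "print(grid_search.best_score_)" ], "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "2EId9Mr9k11f", "outputId": "c7f2ae0b-8107-4774-e753-c49bb9fe22fd" }, "execution_count": 544, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "0.3760951785106917\n" ] } ] } ], "metadata": { "kernelspec": { "display_name": "ai-course", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.10.8" }, "orig_nbformat": 4, "vscode": { "interpreter": { "hash": "62556f7a043365a66e0918c892755cfafede529a87e97207556f006a109bade4" } }, "colab": { "provenance": [] } }, "nbformat": 4, "nbformat_minor": 0 }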