{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "### TU257 - Lab4 - Demo - Naive Bayes\n", "#### Introduction to simple Classification\n", "The examples in this notebook illustrate some of the simple steps needed for classification.\n", "\n", "It is important to remember all the things we have covered in the previous weeks, as all of those\n", "apply to every Classification problem.\n", "But firstly, we will start with some simple examples.\n", "\n", "Work through the first example, examining every step/cell. Add additional annotations and descriptions where you can.\n", "\n", "For the second example, there are very few comments/annotations. The exercise for you is to add these." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# load the iris dataset\n", "from sklearn.datasets import load_iris\n", "iris = load_iris()" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'data': array([[5.1, 3.5, 1.4, 0.2],\n", " [4.9, 3. , 1.4, 0.2],\n", " [4.7, 3.2, 1.3, 0.2],\n", " [4.6, 3.1, 1.5, 0.2],\n", " [5. , 3.6, 1.4, 0.2],\n", " [5.4, 3.9, 1.7, 0.4],\n", " [4.6, 3.4, 1.4, 0.3],\n", " [5. , 3.4, 1.5, 0.2],\n", " [4.4, 2.9, 1.4, 0.2],\n", " [4.9, 3.1, 1.5, 0.1],\n", " [5.4, 3.7, 1.5, 0.2],\n", " [4.8, 3.4, 1.6, 0.2],\n", " [4.8, 3. , 1.4, 0.1],\n", " [4.3, 3. , 1.1, 0.1],\n", " [5.8, 4. , 1.2, 0.2],\n", " [5.7, 4.4, 1.5, 0.4],\n", " [5.4, 3.9, 1.3, 0.4],\n", " [5.1, 3.5, 1.4, 0.3],\n", " [5.7, 3.8, 1.7, 0.3],\n", " [5.1, 3.8, 1.5, 0.3],\n", " [5.4, 3.4, 1.7, 0.2],\n", " [5.1, 3.7, 1.5, 0.4],\n", " [4.6, 3.6, 1. , 0.2],\n", " [5.1, 3.3, 1.7, 0.5],\n", " [4.8, 3.4, 1.9, 0.2],\n", " [5. , 3. , 1.6, 0.2],\n", " [5. 
, 3.4, 1.6, 0.4],\n", " [5.2, 3.5, 1.5, 0.2],\n", " [5.2, 3.4, 1.4, 0.2],\n", " [4.7, 3.2, 1.6, 0.2],\n", " [4.8, 3.1, 1.6, 0.2],\n", " [5.4, 3.4, 1.5, 0.4],\n", " [5.2, 4.1, 1.5, 0.1],\n", " [5.5, 4.2, 1.4, 0.2],\n", " [4.9, 3.1, 1.5, 0.2],\n", " [5. , 3.2, 1.2, 0.2],\n", " [5.5, 3.5, 1.3, 0.2],\n", " [4.9, 3.6, 1.4, 0.1],\n", " [4.4, 3. , 1.3, 0.2],\n", " [5.1, 3.4, 1.5, 0.2],\n", " [5. , 3.5, 1.3, 0.3],\n", " [4.5, 2.3, 1.3, 0.3],\n", " [4.4, 3.2, 1.3, 0.2],\n", " [5. , 3.5, 1.6, 0.6],\n", " [5.1, 3.8, 1.9, 0.4],\n", " [4.8, 3. , 1.4, 0.3],\n", " [5.1, 3.8, 1.6, 0.2],\n", " [4.6, 3.2, 1.4, 0.2],\n", " [5.3, 3.7, 1.5, 0.2],\n", " [5. , 3.3, 1.4, 0.2],\n", " [7. , 3.2, 4.7, 1.4],\n", " [6.4, 3.2, 4.5, 1.5],\n", " [6.9, 3.1, 4.9, 1.5],\n", " [5.5, 2.3, 4. , 1.3],\n", " [6.5, 2.8, 4.6, 1.5],\n", " [5.7, 2.8, 4.5, 1.3],\n", " [6.3, 3.3, 4.7, 1.6],\n", " [4.9, 2.4, 3.3, 1. ],\n", " [6.6, 2.9, 4.6, 1.3],\n", " [5.2, 2.7, 3.9, 1.4],\n", " [5. , 2. , 3.5, 1. ],\n", " [5.9, 3. , 4.2, 1.5],\n", " [6. , 2.2, 4. , 1. ],\n", " [6.1, 2.9, 4.7, 1.4],\n", " [5.6, 2.9, 3.6, 1.3],\n", " [6.7, 3.1, 4.4, 1.4],\n", " [5.6, 3. , 4.5, 1.5],\n", " [5.8, 2.7, 4.1, 1. ],\n", " [6.2, 2.2, 4.5, 1.5],\n", " [5.6, 2.5, 3.9, 1.1],\n", " [5.9, 3.2, 4.8, 1.8],\n", " [6.1, 2.8, 4. , 1.3],\n", " [6.3, 2.5, 4.9, 1.5],\n", " [6.1, 2.8, 4.7, 1.2],\n", " [6.4, 2.9, 4.3, 1.3],\n", " [6.6, 3. , 4.4, 1.4],\n", " [6.8, 2.8, 4.8, 1.4],\n", " [6.7, 3. , 5. , 1.7],\n", " [6. , 2.9, 4.5, 1.5],\n", " [5.7, 2.6, 3.5, 1. ],\n", " [5.5, 2.4, 3.8, 1.1],\n", " [5.5, 2.4, 3.7, 1. ],\n", " [5.8, 2.7, 3.9, 1.2],\n", " [6. , 2.7, 5.1, 1.6],\n", " [5.4, 3. , 4.5, 1.5],\n", " [6. , 3.4, 4.5, 1.6],\n", " [6.7, 3.1, 4.7, 1.5],\n", " [6.3, 2.3, 4.4, 1.3],\n", " [5.6, 3. , 4.1, 1.3],\n", " [5.5, 2.5, 4. , 1.3],\n", " [5.5, 2.6, 4.4, 1.2],\n", " [6.1, 3. , 4.6, 1.4],\n", " [5.8, 2.6, 4. , 1.2],\n", " [5. , 2.3, 3.3, 1. ],\n", " [5.6, 2.7, 4.2, 1.3],\n", " [5.7, 3. 
, 4.2, 1.2],\n", " [5.7, 2.9, 4.2, 1.3],\n", " [6.2, 2.9, 4.3, 1.3],\n", " [5.1, 2.5, 3. , 1.1],\n", " [5.7, 2.8, 4.1, 1.3],\n", " [6.3, 3.3, 6. , 2.5],\n", " [5.8, 2.7, 5.1, 1.9],\n", " [7.1, 3. , 5.9, 2.1],\n", " [6.3, 2.9, 5.6, 1.8],\n", " [6.5, 3. , 5.8, 2.2],\n", " [7.6, 3. , 6.6, 2.1],\n", " [4.9, 2.5, 4.5, 1.7],\n", " [7.3, 2.9, 6.3, 1.8],\n", " [6.7, 2.5, 5.8, 1.8],\n", " [7.2, 3.6, 6.1, 2.5],\n", " [6.5, 3.2, 5.1, 2. ],\n", " [6.4, 2.7, 5.3, 1.9],\n", " [6.8, 3. , 5.5, 2.1],\n", " [5.7, 2.5, 5. , 2. ],\n", " [5.8, 2.8, 5.1, 2.4],\n", " [6.4, 3.2, 5.3, 2.3],\n", " [6.5, 3. , 5.5, 1.8],\n", " [7.7, 3.8, 6.7, 2.2],\n", " [7.7, 2.6, 6.9, 2.3],\n", " [6. , 2.2, 5. , 1.5],\n", " [6.9, 3.2, 5.7, 2.3],\n", " [5.6, 2.8, 4.9, 2. ],\n", " [7.7, 2.8, 6.7, 2. ],\n", " [6.3, 2.7, 4.9, 1.8],\n", " [6.7, 3.3, 5.7, 2.1],\n", " [7.2, 3.2, 6. , 1.8],\n", " [6.2, 2.8, 4.8, 1.8],\n", " [6.1, 3. , 4.9, 1.8],\n", " [6.4, 2.8, 5.6, 2.1],\n", " [7.2, 3. , 5.8, 1.6],\n", " [7.4, 2.8, 6.1, 1.9],\n", " [7.9, 3.8, 6.4, 2. ],\n", " [6.4, 2.8, 5.6, 2.2],\n", " [6.3, 2.8, 5.1, 1.5],\n", " [6.1, 2.6, 5.6, 1.4],\n", " [7.7, 3. , 6.1, 2.3],\n", " [6.3, 3.4, 5.6, 2.4],\n", " [6.4, 3.1, 5.5, 1.8],\n", " [6. , 3. , 4.8, 1.8],\n", " [6.9, 3.1, 5.4, 2.1],\n", " [6.7, 3.1, 5.6, 2.4],\n", " [6.9, 3.1, 5.1, 2.3],\n", " [5.8, 2.7, 5.1, 1.9],\n", " [6.8, 3.2, 5.9, 2.3],\n", " [6.7, 3.3, 5.7, 2.5],\n", " [6.7, 3. , 5.2, 2.3],\n", " [6.3, 2.5, 5. , 1.9],\n", " [6.5, 3. , 5.2, 2. ],\n", " [6.2, 3.4, 5.4, 2.3],\n", " [5.9, 3. 
, 5.1, 1.8]]),\n", " 'target': array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n", " 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,\n", " 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]),\n", " 'frame': None,\n", " 'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
dayoutlooktemphumiditywindplay
0D1SunnyHotHighWeakNo
1D2SunnyHotHighStrongNo
2D3OvercastHotHighWeakYes
3D4RainMildHighWeakYes
4D5RainCoolNormalWeakYes
5D6RainCoolNormalStrongNo
6D7OvercastCoolNormalStrongYes
7D8SunnyMildHighWeakNo
8D9SunnyCoolNormalWeakYes
9D10RainMildNormalWeakYes
10D11SunnyMildNormalStrongYes
11D12OvercastMildHighStrongYes
12D13OvercastHotNormalWeakYes
13D14RainMildHighStrongNo
\n", "" ], "text/plain": [ " day outlook temp humidity wind play\n", "0 D1 Sunny Hot High Weak No\n", "1 D2 Sunny Hot High Strong No\n", "2 D3 Overcast Hot High Weak Yes\n", "3 D4 Rain Mild High Weak Yes\n", "4 D5 Rain Cool Normal Weak Yes\n", "5 D6 Rain Cool Normal Strong No\n", "6 D7 Overcast Cool Normal Strong Yes\n", "7 D8 Sunny Mild High Weak No\n", "8 D9 Sunny Cool Normal Weak Yes\n", "9 D10 Rain Mild Normal Weak Yes\n", "10 D11 Sunny Mild Normal Strong Yes\n", "11 D12 Overcast Mild High Strong Yes\n", "12 D13 Overcast Hot Normal Weak Yes\n", "13 D14 Rain Mild High Strong No" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "#Load in the dataset\n", "df = pd.read_csv('/Users/brendan.tierney/Dropbox/4-Datasets/play_tennis.csv')\n", "df.head(20)" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(14, 6)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#Number of rows and columns (features)\n", "df.shape" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
dayoutlooktemphumiditywindplay
0021010
1621000
2701011
3812011
4910111
51010100
61100101
71222010
81320111
9112111
10222101
11302001
12401111
13512000
\n", "
" ], "text/plain": [ " day outlook temp humidity wind play\n", "0 0 2 1 0 1 0\n", "1 6 2 1 0 0 0\n", "2 7 0 1 0 1 1\n", "3 8 1 2 0 1 1\n", "4 9 1 0 1 1 1\n", "5 10 1 0 1 0 0\n", "6 11 0 0 1 0 1\n", "7 12 2 2 0 1 0\n", "8 13 2 0 1 1 1\n", "9 1 1 2 1 1 1\n", "10 2 2 2 1 0 1\n", "11 3 0 2 0 0 1\n", "12 4 0 1 1 1 1\n", "13 5 1 2 0 0 0" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#Transform the text (categorical) columns into numeric codes\n", "#Use Label Encoding to do this quickly\n", "from sklearn import preprocessing \n", "\n", "#Setup the Label Encoder\n", "le=preprocessing.LabelEncoder()\n", "#Loop through the columns\n", "for col in df.columns:\n", " #transform the column\n", " df[col]=le.fit_transform(df[col])\n", "\n", "#Display the updated dataframe\n", "df" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
dayoutlooktemphumiditywind
002101
162100
270101
381201
491011
\n", "
" ], "text/plain": [ " day outlook temp humidity wind\n", "0 0 2 1 0 1\n", "1 6 2 1 0 0\n", "2 7 0 1 0 1\n", "3 8 1 2 0 1\n", "4 9 1 0 1 1" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#There are various ways to do the next step\n", "#Here we separate the descriptive features from the Target features\n", "X = df.drop('play', axis=1)\n", "\n", "#Create a new dataframe (Y) to only contain the Target attribute\n", "Y = df['play']\n", "X.head()" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "# use Stratified sampling to divide the data\n", "from sklearn.model_selection import StratifiedShuffleSplit\n", "\n", "#Setup the sampling and splitting the data\n", "split = StratifiedShuffleSplit(n_splits=10, test_size = 0.2, random_state=18)\n", "\n", "#Create the Train and Test datasets\n", "for train_index, test_index in split.split(X, Y):\n", " train_set = df.loc[train_index]\n", " test_set = df.loc[test_index]" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "#Exercise: Modify the above cell to replace X, Y in the split function with the dataframe subsets\n" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(11, 6)\n", "(3, 6)\n" ] } ], "source": [ "#display the sizes of the Train and Test datasets\n", "print(train_set.shape)\n", "print(test_set.shape)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
dayoutlooktemphumiditywindplay
10222101
4910111
71222010
3812011
11302001
9112111
13512000
0021010
1621000
61100101
81320111
\n", "
" ], "text/plain": [ " day outlook temp humidity wind play\n", "10 2 2 2 1 0 1\n", "4 9 1 0 1 1 1\n", "7 12 2 2 0 1 0\n", "3 8 1 2 0 1 1\n", "11 3 0 2 0 0 1\n", "9 1 1 2 1 1 1\n", "13 5 1 2 0 0 0\n", "0 0 2 1 0 1 0\n", "1 6 2 1 0 0 0\n", "6 11 0 0 1 0 1\n", "8 13 2 0 1 1 1" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_set" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
dayoutlooktemphumiditywindplay
12401111
2701011
51010100
\n", "
" ], "text/plain": [ " day outlook temp humidity wind play\n", "12 4 0 1 1 1 1\n", "2 7 0 1 0 1 1\n", "5 10 1 0 1 0 0" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#This is a very small dataset\n", "test_set" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1 7\n", "0 4\n", "Name: play, dtype: int64" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_set['play'].value_counts()" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1 2\n", "0 1\n", "Name: play, dtype: int64" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_set['play'].value_counts()" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "#Separate the datasets into X, Y\n", "X_train = train_set.drop('play', axis=1)\n", "X_test = test_set.drop('play', axis=1)\n", "Y_train = train_set['play']\n", "Y_test = test_set['play']" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "#Import Gaussian Naive Bayes model\n", "from sklearn.naive_bayes import GaussianNB\n", "\n", "#Create a Gaussian Classifier\n", "model = GaussianNB()" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "GaussianNB()" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Train (fit) the model using the training sets\n", "model.fit(X_train,Y_train)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [], "source": [ "#Apply the model to the Test dataset to create the Predicted values\n", "predicted= model.predict(X_test) " ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1, 0, 1])" ] }, "execution_count": 
45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "predicted" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
play01
row_0
001
111
\n", "
" ], "text/plain": [ "play 0 1\n", "row_0 \n", "0 0 1\n", "1 1 1" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "#Crosstab the columns to give a basic confusion matrix\n", "pd.crosstab(predicted, Y_test)" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [], "source": [ "#We can also get the prediction probabilities\n", "#Column 1 of predict_proba holds the predicted probability of class 1 for each test row\n", "#A value close to 1 = a strong prediction of class 1\n", "#A value close to 0 = a strong prediction of class 0\n", "Y_predict_prob = model.predict_proba(X_test)[:,1]" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([1. , 0.1098402, 1. ])" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "Y_predict_prob" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import seaborn as sns\n", "from sklearn.metrics import confusion_matrix\n", "import matplotlib.pyplot as plt\n", "\n", "# passing actual and predicted values\n", "cm = confusion_matrix(Y_test, predicted)\n", "\n", "# annot=True writes the data values in each cell of the matrix\n", "sns.heatmap(cm, annot=True)\n", "plt.title('Confusion Matrix', fontsize = 15) \n", "plt.xlabel('Predicted', fontsize = 13) \n", "plt.ylabel('Acuals', fontsize = 13) \n", "\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.3" } }, "nbformat": 4, "nbformat_minor": 2 }