{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": { "tags": [ "remove_input" ] }, "outputs": [], "source": [ "path_data = '../../data/'\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import math\n", "import scipy.stats as stats\n", "\n", "%matplotlib inline\n", "import matplotlib.pyplot as plt\n", "from mpl_toolkits.mplot3d import Axes3D\n", "plt.style.use('fivethirtyeight')\n", "\n", "import warnings\n", "warnings.filterwarnings('ignore')" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "tags": [ "remove_input" ] }, "outputs": [], "source": [ "def standard_units(x):\n", "    return (x - np.mean(x))/np.std(x)" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "tags": [ "remove_input" ] }, "outputs": [], "source": [ "def distance(point1, point2):\n", "    \"\"\"The distance between two arrays of numbers.\"\"\"\n", "    return np.sqrt(np.sum((point1 - point2)**2))\n", "\n", "def all_distances(training, point):\n", "    \"\"\"The distance between point (an array of numbers) and the attributes in each row of training.\"\"\"\n", "    attributes = training.drop(columns=['Class'])\n", "    def distance_from_point(row):\n", "        return distance(point, np.array(row))\n", "    return attributes.apply(distance_from_point, axis=1)\n", "\n", "def table_with_distances(training, point):\n", "    \"\"\"A copy of the training table with the distance from each row to point.\"\"\"\n", "    training1 = training.copy()\n", "    training1['Distance'] = all_distances(training1, point)\n", "    return training1\n", "\n", "def closest(training, point, k):\n", "    \"\"\"A table containing the k rows of the training table closest to point.\"\"\"\n", "    with_dists = table_with_distances(training, point)\n", "    sorted_by_distance = with_dists.sort_values(by=['Distance'])\n", "    topk = sorted_by_distance.take(np.arange(k))\n", "    return topk\n", "\n", "def majority(topkclasses):\n", "    \"\"\"1 if the majority of the \"Class\" column is 1s, and 0 otherwise.\"\"\"\n", "    # Cast to int so that labels stored as the strings '1' and '0' are counted too\n", "    classes = topkclasses['Class'].astype(int)\n", "    ones = np.count_nonzero(classes == 1)\n", "    zeros = np.count_nonzero(classes == 0)\n", "    if ones > zeros:\n", "        return 1\n", "    else:\n", "        return 0\n", "\n", "def classify(training, p, k):\n", "    \"\"\"Classify an example with attributes p using k-nearest neighbor classification with the given training table.\"\"\"\n", "    closestk = closest(training, p, k)\n", "    topkclasses = closestk[['Class']]\n", "    return majority(topkclasses)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Nearest Neighbors\n", "In this section we'll develop the *nearest neighbor* method of classification. Just focus on the ideas for now and don't worry if some of the code is mysterious. Later in the chapter we'll see how to organize our ideas into code that performs the classification." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Chronic kidney disease\n", "Let's work through an example. We're going to work with a data set that was collected to help doctors diagnose chronic kidney disease (CKD). Each row in the data set represents a single patient who was treated in the past and whose diagnosis is known. For each patient, we have a bunch of measurements from a blood test. We'd like to find which measurements are most useful for diagnosing CKD, and develop a way to classify future patients as \"has CKD\" or \"doesn't have CKD\" based on their blood test results." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
[HTML table: 158 rows × 25 columns]
" ], "text/plain": [ " Age Blood Pressure Specific Gravity Albumin Sugar Red Blood Cells \\\n", "0 48 70 1.005 4 0 normal \n", "1 53 90 1.020 2 0 abnormal \n", "2 63 70 1.010 3 0 abnormal \n", "3 68 80 1.010 3 2 normal \n", "4 61 80 1.015 2 0 abnormal \n", ".. ... ... ... ... ... ... \n", "153 55 80 1.020 0 0 normal \n", "154 42 70 1.025 0 0 normal \n", "155 12 80 1.020 0 0 normal \n", "156 17 60 1.025 0 0 normal \n", "157 58 80 1.025 0 0 normal \n", "\n", " Pus Cell Pus Cell clumps Bacteria Glucose ... Packed Cell Volume \\\n", "0 abnormal present notpresent 117 ... 32 \n", "1 abnormal present notpresent 70 ... 29 \n", "2 abnormal present notpresent 380 ... 32 \n", "3 abnormal present present 157 ... 16 \n", "4 abnormal notpresent notpresent 173 ... 24 \n", ".. ... ... ... ... ... ... \n", "153 normal notpresent notpresent 140 ... 47 \n", "154 normal notpresent notpresent 75 ... 54 \n", "155 normal notpresent notpresent 100 ... 49 \n", "156 normal notpresent notpresent 114 ... 51 \n", "157 normal notpresent notpresent 131 ... 53 \n", "\n", " White Blood Cell Count Red Blood Cell Count Hypertension \\\n", "0 6700 3.9 yes \n", "1 12100 3.7 yes \n", "2 4500 3.8 yes \n", "3 11000 2.6 yes \n", "4 9200 3.2 yes \n", ".. ... ... ... \n", "153 6700 4.9 no \n", "154 7800 6.2 no \n", "155 6600 5.4 no \n", "156 7200 5.9 no \n", "157 6800 6.1 no \n", "\n", " Diabetes Mellitus Coronary Artery Disease Appetite Pedal Edema Anemia \\\n", "0 no no poor yes yes \n", "1 yes no poor no yes \n", "2 yes no poor yes no \n", "3 yes yes poor yes no \n", "4 yes yes poor yes yes \n", ".. ... ... ... ... ... \n", "153 no no good no no \n", "154 no no good no no \n", "155 no no good no no \n", "156 no no good no no \n", "157 no no good no no \n", "\n", " Class \n", "0 1 \n", "1 1 \n", "2 1 \n", "3 1 \n", "4 1 \n", ".. ... 
\n", "153 0 \n", "154 0 \n", "155 0 \n", "156 0 \n", "157 0 \n", "\n", "[158 rows x 25 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ckd = pd.read_csv(path_data + 'ckd.csv')\n", "ckd.rename(columns={'Blood Glucose Random':'Glucose'}, inplace=True)\n", "ckd" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Some of the variables are categorical (words like \"abnormal\"), and some quantitative. The quantitative variables all have different scales. We're going to want to make comparisons and estimate distances, often by eye, so let's select just a few of the variables and work in standard units. Then we won't have to worry about the scale of each of the different variables." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
[HTML table: 158 rows × 4 columns]
" ], "text/plain": [ " Hemoglobin Glucose White Blood Cell Count Class\n", "0 -0.865744 -0.221549 -0.569768 1\n", "1 -1.457446 -0.947597 1.162684 1\n", "2 -1.004968 3.841231 -1.275582 1\n", "3 -2.814879 0.396364 0.809777 1\n", "4 -2.083954 0.643529 0.232293 1\n", ".. ... ... ... ...\n", "153 0.700526 0.133751 -0.569768 0\n", "154 0.978974 -0.870358 -0.216861 0\n", "155 0.735332 -0.484162 -0.601850 0\n", "156 0.178436 -0.267893 -0.409356 0\n", "157 0.735332 -0.005280 -0.537686 0\n", "\n", "[158 rows x 4 columns]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ckd_su = pd.DataFrame({'Hemoglobin':standard_units(ckd['Hemoglobin']), \n", " 'Glucose':standard_units(ckd['Glucose']), \n", " 'White Blood Cell Count':standard_units(ckd['White Blood Cell Count']), \n", " 'Class':ckd['Class'].astype(str)})\n", "ckd_su" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Let's look at two columns in particular: the hemoglobin level (in the patient's blood), and the blood glucose level (at a random time in the day; without fasting specially for the blood test). \n", "\n", "We'll draw a scatter plot to visualize the relation between the two variables. Blue dots are patients with CKD; gold dots are patients without CKD. What kind of medical test results seem to indicate CKD? \n", "\n", "Previously, to create a df containing the required columns we have used the `join` function, in this example we will use the `merge` function.\n", "\n", "[pandas.merge](https://pandas.pydata.org/pandas-docs/stable/user_guide/merging.html#database-style-dataframe-or-named-series-joining-merging)" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
[HTML table: 158 rows × 5 columns]
" ], "text/plain": [ " Hemoglobin Glucose White Blood Cell Count Class Color\n", "0 -0.865744 -0.221549 -0.569768 1 darkblue\n", "1 -1.457446 -0.947597 1.162684 1 darkblue\n", "2 -1.004968 3.841231 -1.275582 1 darkblue\n", "3 -2.814879 0.396364 0.809777 1 darkblue\n", "4 -2.083954 0.643529 0.232293 1 darkblue\n", ".. ... ... ... ... ...\n", "153 0.700526 0.133751 -0.569768 0 gold\n", "154 0.978974 -0.870358 -0.216861 0 gold\n", "155 0.735332 -0.484162 -0.601850 0 gold\n", "156 0.178436 -0.267893 -0.409356 0 gold\n", "157 0.735332 -0.005280 -0.537686 0 gold\n", "\n", "[158 rows x 5 columns]" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "color_table = pd.DataFrame(\n", " {'Class':np.array([1, 0]),\n", " 'Color':np.array(['darkblue', 'gold'])}, index=np.array([1,0]))\n", " \n", "color_table['Class'] = color_table['Class'].astype(str)\n", "\n", "ckd_combined = pd.merge(ckd_su, color_table, on='Class')\n", "\n", "ckd_combined" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "glucose_color_darkblue = ckd_combined[ckd_combined['Color'] == 'darkblue']\n", "glucose_color_gold = ckd_combined[ckd_combined['Color'] == 'gold']\n", "\n", "\n", "fig, ax = plt.subplots(figsize=(7,6))\n", "\n", "ax.scatter(glucose_color_darkblue['Hemoglobin'], \n", " glucose_color_darkblue['Glucose'], \n", " label='Color=darkblue', \n", " color='darkblue')\n", "\n", "ax.scatter(glucose_color_gold['Hemoglobin'], \n", " glucose_color_gold['Glucose'], \n", " label='Color=gold', \n", " color='gold')\n", "\n", "x_label = 'Hemoglobin'\n", "\n", "y_label = 'Glucose'\n", "\n", "y_vals = ax.get_yticks()\n", "\n", "plt.ylabel(y_label)\n", "\n", "ax.legend(bbox_to_anchor=(1.04,1), loc=\"upper left\")\n", "\n", "plt.xlabel(x_label)\n", "\n", "plt.show()\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Suppose Alice is a new patient who is not in the data set. If I tell you Alice's hemoglobin level and blood glucose level, could you predict whether she has CKD? It sure looks like it! You can see a very clear pattern here: points in the lower-right tend to represent people who don't have CKD, and the rest tend to be folks with CKD. To a human, the pattern is obvious. But how can we program a computer to automatically detect patterns such as this one?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## A Nearest Neighbor Classifier\n", "There are lots of kinds of patterns one might look for, and lots of algorithms for classification. But I'm going to tell you about one that turns out to be surprisingly effective. It is called *nearest neighbor classification*. Here's the idea. If we have Alice's hemoglobin and glucose numbers, we can put her somewhere on this scatterplot; the hemoglobin is her x-coordinate, and the glucose is her y-coordinate. 
Now, to predict whether she has CKD or not, we find the nearest point in the scatterplot and check whether it is blue or gold; we predict that Alice should receive the same diagnosis as that patient.\n", "\n", "In other words, to classify Alice as CKD or not, we find the patient in the training set who is \"nearest\" to Alice, and then use that patient's diagnosis as our prediction for Alice. The intuition is that if two points are near each other in the scatterplot, then the corresponding measurements are pretty similar, so we might expect them to receive the same diagnosis (more likely than not). We don't know Alice's diagnosis, but we do know the diagnosis of all the patients in the training set, so we find the patient in the training set who is most similar to Alice, and use that patient's diagnosis to predict Alice's diagnosis." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In the graph below, the red dot represents Alice. It is joined with a black line to the point that is nearest to it – its *nearest neighbor* in the training set. The figure is drawn by a function called `show_closest`. It takes an array that represents the $x$ and $y$ coordinates of Alice's point. Vary those to see how the closest point changes! Note especially when the closest point is blue and when it is gold." 
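] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
The lookup that `show_closest` illustrates can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the notebook's own helper code: the function name `nearest_neighbor_label` and the four training points are made up, and labels follow the convention 1 = CKD (blue), 0 = not CKD (gold).

```python
import numpy as np

def nearest_neighbor_label(train_points, train_labels, new_point):
    # Hypothetical helper: Euclidean distance from new_point to every training row
    distances = np.sqrt(np.sum((train_points - new_point) ** 2, axis=1))
    # The training row with the smallest distance is the nearest neighbor
    return train_labels[np.argmin(distances)]

# Made-up training set in standard units: columns are (hemoglobin, glucose)
train_points = np.array([[-1.5, 1.0], [-1.0, 2.0], [0.8, -0.5], [1.0, -0.2]])
train_labels = np.array([1, 1, 0, 0])  # 1 = CKD, 0 = not CKD

alice = np.array([0.5, -0.3])
prediction = nearest_neighbor_label(train_points, train_labels, alice)
print(prediction)
```

With these made-up points the training point closest to `alice` is a gold one, so the sketch predicts 0 (not CKD).
"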
] }, { "cell_type": "code", "execution_count": 8, "metadata": { "tags": [ "remove_input" ] }, "outputs": [], "source": [ "def show_closest(point):\n", " \"\"\"point = array([x,y]) \n", " gives the coordinates of a new point\n", " shown in red\"\"\"\n", " \n", " HemoG1 = ckd_combined.copy()\n", " HemoG1 = HemoG1.drop(columns=['White Blood Cell Count', 'Color'])\n", " \n", " t = closest(HemoG1, point, 1)\n", " x_closest = t.iloc[0,0]\n", " y_closest = t.iloc[0,1]\n", "\n", " fig, ax = plt.subplots(figsize=(7,6))\n", " ax.scatter(glucose_color_darkblue['Hemoglobin'], \n", " glucose_color_darkblue['Glucose'], \n", " label='Color=darkblue', \n", " color='darkblue')\n", " ax.scatter(glucose_color_gold['Hemoglobin'], \n", " glucose_color_gold['Glucose'], \n", " label='Color=gold', \n", " color='gold')\n", " x_label = 'Hemoglobin'\n", " y_label = 'Glucose'\n", " y_vals = ax.get_yticks()\n", " plt.ylabel(y_label)\n", " ax.legend(bbox_to_anchor=(1.04,1), loc=\"upper left\")\n", " plt.xlabel(x_label)\n", " \n", " ax.scatter(point.item(0), point.item(1), color='red', s=30)\n", " ax.plot(np.array([point.item(0), x_closest]), np.array([point.item(1), y_closest]), color='k', lw=2)\n", " \n", " plt.show()" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# In this example, Alice's Hemoglobin attribute is 0 and her Glucose is 1.5.\n", "alice = np.array([0, 1.5])\n", "show_closest(alice)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Thus our *nearest neighbor classifier* works like this:\n", "- Find the point in the training set that is nearest to the new point.\n", "- If that nearest point is a \"CKD\" point, classify the new point as \"CKD\". If the nearest point is a \"not CKD\" point, classify the new point as \"not CKD\".\n", "\n", "The scatterplot suggests that this nearest neighbor classifier should be pretty accurate. Points in the lower-right will tend to receive a \"no CKD\" diagnosis, as their nearest neighbor will be a gold point. The rest of the points will tend to receive a \"CKD\" diagnosis, as their nearest neighbor will be a blue point. So the nearest neighbor strategy seems to capture our intuition pretty well, for this example." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Decision boundary\n", "Sometimes a helpful way to visualize a classifier is to map out the kinds of attributes where the classifier would predict 'CKD', and the kinds where it would predict 'not CKD'. We end up with some boundary between the two, where points on one side of the boundary will be classified 'CKD' and points on the other side will be classified 'not CKD'. This boundary is called the *decision boundary*. Each different classifier will have a different decision boundary; the decision boundary is just a way to visualize what criteria the classifier is using to classify points.\n", "\n", "For example, suppose the coordinates of Alice's point are (0, 1.5). Notice that the nearest neighbor is blue. Now try reducing the height (the $y$-coordinate) of the point. You'll see that at around $y = 0.95$ the nearest neighbor turns from blue to gold." 
] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "alice = np.array([0, 0.97])\n", "show_closest(alice)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Here are hundreds of new unclassified points, all in red." ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "tags": [ "remove_input" ] }, "outputs": [ { "data": { "text/html": [ "
[HTML table: 1681 rows × 2 columns]
" ], "text/plain": [ " Hemoglobin Glucose\n", "0 -2.0 -2.0\n", "1 -2.0 -1.9\n", "2 -2.0 -1.8\n", "3 -2.0 -1.7\n", "4 -2.0 -1.6\n", "... ... ...\n", "1676 2.0 1.6\n", "1677 2.0 1.7\n", "1678 2.0 1.8\n", "1679 2.0 1.9\n", "1680 2.0 2.0\n", "\n", "[1681 rows x 2 columns]" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_array = np.array([])\n", "y_array = np.array([])\n", "for x in np.arange(-2, 2.1, 0.1):\n", " for y in np.arange(-2, 2.1, 0.1):\n", " x_array = np.append(x_array, x)\n", " y_array = np.append(y_array, y)\n", " \n", "test_grid = pd.DataFrame(\n", " {'Hemoglobin':x_array,\n", " 'Glucose':y_array}\n", ")\n", "\n", "test_grid" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "tags": [ "remove_input" ] }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "test_grid.plot.scatter('Hemoglobin', 'Glucose', color='red', figsize=(6,6), alpha=0.4, s=30)\n", "\n", "plt.scatter(ckd_combined['Hemoglobin'], ckd_combined['Glucose'], c=ckd_combined['Color'], edgecolor='k')\n", "\n", "plt.xlim(-2, 2)\n", "plt.ylim(-2, 2);" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Each of the red points has a nearest neighbor in the training set (the same blue and gold points as before). For some red points you can easily tell whether the nearest neighbor is blue or gold. For others, it's a little more tricky to make the decision by eye. Those are the points near the decision boundary.\n", "\n", "But the computer can easily determine the nearest neighbor of each point. So let's get it to apply our nearest neighbor classifier to each of the red points: \n", "\n", "For each red point, it must find the closest point in the training set; it must then change the color of the red point to become the color of the nearest neighbor. \n", "\n", "The resulting graph shows which points will get classified as 'CKD' (all the blue ones), and which as 'not CKD' (all the gold ones)." 
] }, { "cell_type": "code", "execution_count": 13, "metadata": { "tags": [ "remove_input" ] }, "outputs": [], "source": [ "def classify_grid(training, test, k):\n", "    \"\"\"Classify each row of test by its k nearest neighbors in training; returns an array of 1s and 0s.\"\"\"\n", "    predictions = np.array([])\n", "    for i in range(len(test)):\n", "        # Run the classifier on the ith point in the test grid\n", "        nearest = closest(training, np.array([test.iloc[i]]), k)\n", "\n", "        # Class labels are stored as the strings '1' and '0'\n", "        ones = len(nearest[nearest['Class'] == '1'])\n", "        zeros = len(nearest[nearest['Class'] == '0'])\n", "\n", "        if ones > zeros:\n", "            predictions = np.append(predictions, 1)\n", "        else:\n", "            predictions = np.append(predictions, 0)\n", "\n", "    return predictions\n" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "tags": [ "remove_input" ] }, "outputs": [ { "data": { "text/plain": [ "array([1., 1., 1., ..., 1., 1., 1.])" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ckd_new = classify_grid(ckd_combined.drop(columns=['White Blood Cell Count', 'Color']), test_grid, 1)\n", "\n", "ckd_new" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
[HTML table: 5 rows × 3 columns]
" ], "text/plain": [ " Hemoglobin Glucose Class\n", "0 -2.0 -2.0 1\n", "1 -2.0 -1.9 1\n", "2 -2.0 -1.8 1\n", "3 -2.0 -1.7 1\n", "4 -2.0 -1.6 1" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_grid['Class'] = ckd_new.astype(int)\n", "test_grid['Class'] = test_grid['Class'].astype(str)\n", "\n", "test_grid.head(5)" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
[HTML table: 1681 rows × 4 columns]
" ], "text/plain": [ " Hemoglobin Glucose Class Color\n", "0 -2.0 -2.0 1 darkblue\n", "1 -2.0 -1.9 1 darkblue\n", "2 -2.0 -1.8 1 darkblue\n", "3 -2.0 -1.7 1 darkblue\n", "4 -2.0 -1.6 1 darkblue\n", "... ... ... ... ...\n", "1676 2.0 0.6 0 gold\n", "1677 2.0 0.7 0 gold\n", "1678 2.0 0.8 0 gold\n", "1679 2.0 0.9 0 gold\n", "1680 2.0 1.0 0 gold\n", "\n", "[1681 rows x 4 columns]" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "test_grid1 = pd.DataFrame({'Hemoglobin':test_grid['Hemoglobin'], \n", " 'Glucose':test_grid['Glucose'], \n", " 'Class':test_grid['Class']})\n", "\n", "\n", "test_grid2 = pd.merge(test_grid1, color_table, on='Class')\n", "\n", "test_grid2" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "tags": [ "remove_input" ] }, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "test_grid3 = test_grid2[test_grid2['Color'] == 'darkblue']\n", "test_grid4 = test_grid2[test_grid2['Color'] == 'gold']\n", "\n", "plt.scatter(test_grid3['Hemoglobin'], test_grid3['Glucose'], color='darkblue', alpha=0.4, s=30)\n", "plt.scatter(test_grid4['Hemoglobin'], test_grid4['Glucose'], color='gold', alpha=0.4, s=30)\n", "\n", "plt.scatter(ckd_combined['Hemoglobin'], ckd_combined['Glucose'], c=ckd_combined['Color'], edgecolor='k')\n", "\n", "plt.xlim(-2, 2)\n", "plt.ylim(-2, 2);\n" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "test_grid5 = ckd_combined[ckd_combined['Color'] == 'darkblue']\n", "test_grid6 = ckd_combined[ckd_combined['Color'] == 'gold']\n", "\n", "fig, ax = plt.subplots(figsize=(7,6))\n", "\n", "ax.scatter(test_grid3['Hemoglobin'], \n", " test_grid3['Glucose'], \n", " color='darkblue', alpha=0.4, s=30)\n", "\n", "ax.scatter(test_grid4['Hemoglobin'], \n", " test_grid4['Glucose'], \n", " color='gold', alpha=0.4, s=30)\n", "\n", "ax.scatter(test_grid5['Hemoglobin'], \n", " test_grid5['Glucose'], \n", " color='darkblue', label='Color=darkblue', ec='darkblue', s=30)\n", "\n", "ax.scatter(test_grid6['Hemoglobin'], \n", " test_grid6['Glucose'], \n", " color='gold', label='Color=gold', ec='darkblue', s=30)\n", "\n", "x_label = 'Hemoglobin'\n", "\n", "y_label = 'Glucose'\n", "\n", "plt.xlabel(x_label)\n", "\n", "plt.ylabel(y_label)\n", "\n", "ax.legend(bbox_to_anchor=(1.04,1), loc=\"upper left\")\n", "\n", "plt.xlabel(x_label)\n", "\n", "plt.xlim(-2, 2)\n", "\n", "plt.ylim(-2, 2)\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The decision boundary is where the classifier switches from turning the red points blue to turning them gold." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### k-Nearest Neighbors\n", "However, the separation between the two classes won't always be quite so clean. For instance, suppose that instead of hemoglobin levels we were to look at white blood cell count. Look at what happens:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "glucose_color_darkblue = ckd_combined[ckd_combined['Color'] == 'darkblue']\n", "glucose_color_gold = ckd_combined[ckd_combined['Color'] == 'gold']\n", "\n", "fig, ax = plt.subplots(figsize=(7,6))\n", "\n", "ax.scatter(glucose_color_darkblue['White Blood Cell Count'], \n", " glucose_color_darkblue['Glucose'], \n", " label='Color=darkblue', \n", " color='darkblue')\n", "\n", "ax.scatter(glucose_color_gold['White Blood Cell Count'], \n", " glucose_color_gold['Glucose'], \n", " label='Color=gold', \n", " color='gold')\n", "\n", "x_label = 'White Blood Cell Count'\n", "\n", "y_label = 'Glucose'\n", "\n", "y_vals = ax.get_yticks()\n", "\n", "plt.ylabel(y_label)\n", "\n", "ax.legend(bbox_to_anchor=(1.04,1), loc=\"upper left\")\n", "\n", "plt.xlabel(x_label)\n", "\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "As you can see, non-CKD individuals are all clustered in the lower-left. Most of the patients with CKD are above or to the right of that cluster... but not all. There are some patients with CKD who are in the lower left of the above figure (as indicated by the handful of blue dots scattered among the gold cluster). What this means is that you can't tell for certain whether someone has CKD from just these two blood test measurements.\n", "\n", "If we are given Alice's glucose level and white blood cell count, can we predict whether she has CKD? Yes, we can make a prediction, but we shouldn't expect it to be 100% accurate. Intuitively, it seems like there's a natural strategy for predicting: plot where Alice lands in the scatter plot; if she is in the lower-left, predict that she doesn't have CKD, otherwise predict she has CKD. \n", "\n", "This isn't perfect -- our predictions will sometimes be wrong. (Take a minute and think it through: for which patients will it make a mistake?) 
As the scatterplot above indicates, sometimes people with CKD have glucose and white blood cell levels that look identical to those of someone without CKD, so any classifier is inevitably going to make the wrong prediction for them.\n", "\n", "Can we automate this on a computer? Well, the nearest neighbor classifier would be a reasonable choice here too. Take a minute and think it through: how will its predictions compare to those from the intuitive strategy above? When will they differ?\n", "\n", "Its predictions will be pretty similar to those of our intuitive strategy, but occasionally it will make a different prediction. In particular, if Alice's blood test results happen to put her right near one of the blue dots in the lower-left, the intuitive strategy would predict 'not CKD', whereas the nearest neighbor classifier will predict 'CKD'.\n", "\n", "There is a simple generalization of the nearest neighbor classifier that fixes this anomaly. It is called the *k-nearest neighbor classifier*. To predict Alice's diagnosis, rather than looking at just the one neighbor closest to her, we can look at the 3 points that are closest to her, and use the diagnosis for each of those 3 points to predict Alice's diagnosis. In particular, we'll use the majority value among those 3 diagnoses as our prediction for Alice's diagnosis. Of course, there's nothing special about the number 3: we could use 4, or 5, or more. (It's often convenient to pick an odd number, so that we don't have to deal with ties.) In general, we pick a number $k$, and our predicted diagnosis for Alice is based on the $k$ patients in the training set who are closest to Alice. Intuitively, these are the $k$ patients whose blood test results were most similar to Alice's, so it seems reasonable to use their diagnoses to predict Alice's diagnosis.\n", "\n", "With a suitable $k$ (3, say), the $k$-nearest neighbor classifier will behave much more like our intuitive strategy above: an isolated blue dot inside the gold cluster is outvoted by its gold neighbors." 
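] }, { "cell_type": "markdown", "metadata": {}, "source": [ "
To make the majority vote concrete, here is a minimal sketch of a $k$-nearest neighbor classifier in NumPy. The function name `knn_classify` and the data are invented for illustration: one stray CKD point sits inside a cluster of non-CKD points, mimicking the blue dots scattered among the gold ones.

```python
import numpy as np

def knn_classify(train_points, train_labels, new_point, k):
    # Euclidean distance from new_point to every training point
    distances = np.sqrt(np.sum((train_points - new_point) ** 2, axis=1))
    # Indices of the k closest training points
    nearest = np.argsort(distances)[:k]
    # Majority vote over 0/1 labels; with odd k there are no ties
    ones = np.count_nonzero(train_labels[nearest] == 1)
    return 1 if ones > k - ones else 0

# Made-up data: a gold (0) cluster with one stray blue (1) point inside it
train_points = np.array([[0.0, 0.0], [0.3, 0.1], [0.1, 0.4], [-0.2, 0.2], [0.15, 0.15]])
train_labels = np.array([0, 0, 0, 0, 1])

alice = np.array([0.2, 0.2])
print(knn_classify(train_points, train_labels, alice, 1))
print(knn_classify(train_points, train_labels, alice, 3))
```

With `k = 1` the stray blue point is Alice's nearest neighbor, so she is labeled 1 (CKD). With `k = 3` its two gold neighbors outvote it and she is labeled 0 (not CKD).
"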
] } ], "metadata": { "anaconda-cloud": {}, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 1 }