{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": [
     "remove_input"
    ]
   },
   "outputs": [],
   "source": [
    "path_data = '../../data/'\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import math\n",
    "import scipy.stats as stats\n",
    "\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "plt.style.use('fivethirtyeight')\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": [
     "remove_input"
    ]
   },
   "outputs": [],
   "source": [
    "def standard_units(x):\n",
    "    \"\"\"Convert x to standard units: (x - mean) / SD, with np.std's default (population) SD.\"\"\"\n",
    "    return (x - np.mean(x))/np.std(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": [
     "remove_input"
    ]
   },
   "outputs": [],
   "source": [
    "def distance(point1, point2):\n",
    "    \"\"\"The Euclidean distance between two arrays of numbers.\"\"\"\n",
    "    return np.sqrt(np.sum((point1 - point2)**2))\n",
    "\n",
    "def all_distances(training, point):\n",
    "    \"\"\"The distance from `point` (an array of numbers) to each row of `training`, using every column except 'Class' as an attribute.\"\"\"\n",
    "    attributes = training.drop(columns=['Class'])\n",
    "    def distance_from_point(row):\n",
    "        return distance(point, np.array(row))\n",
    "    return attributes.apply(distance_from_point, axis=1)\n",
    "\n",
    "def table_with_distances(training, point):\n",
    "    \"\"\"A copy of the training table with a 'Distance' column holding each row's distance to `point`.\"\"\"\n",
    "    training1 = training.copy()\n",
    "    training1['Distance'] = all_distances(training1, point)\n",
    "    return training1\n",
    "\n",
    "def closest(training, point, k):\n",
    "    \"\"\"A table containing the k rows of the training table closest to `point`, nearest first.\"\"\"\n",
    "    with_dists = table_with_distances(training, point)\n",
    "    sorted_by_distance = with_dists.sort_values(by=['Distance'])\n",
    "    # head(k) returns every row (instead of raising) when k exceeds the table size.\n",
    "    return sorted_by_distance.head(k)\n",
    "\n",
    "def majority(topkclasses):\n",
    "    \"\"\"1 if the majority of the \"Class\" column is 1s, and 0 otherwise (a tie returns 0).\"\"\"\n",
    "    ones = len(topkclasses[topkclasses['Class'] == 1])\n",
    "    zeros = len(topkclasses[topkclasses['Class'] == 0])\n",
    "    if ones > zeros:\n",
    "        return 1\n",
    "    else:\n",
    "        return 0\n",
    "\n",
    "def classify(training, p, k):\n",
    "    \"\"\"Classify an example with attributes p using k-nearest neighbor classification with the given training table.\"\"\"\n",
    "    closestk = closest(training, p, k)\n",
    "    topkclasses = closestk[['Class']]\n",
    "    return majority(topkclasses)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Nearest Neighbors\n",
    "In this section we'll develop the *nearest neighbor* method of classification. Just focus on the ideas for now and don't worry if some of the code is mysterious. Later in the chapter we'll see how to organize our ideas into code that performs the classification."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Chronic kidney disease\n",
    "Let's work through an example. We're going to work with a data set that was collected to help doctors diagnose chronic kidney disease (CKD). Each row in the data set represents a single patient who was treated in the past and whose diagnosis is known. For each patient, we have a bunch of measurements from a blood test. We'd like to find which measurements are most useful for diagnosing CKD, and develop a way to classify future patients as \"has CKD\" or \"doesn't have CKD\" based on their blood test results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "
\n", " | Age | \n", "Blood Pressure | \n", "Specific Gravity | \n", "Albumin | \n", "Sugar | \n", "Red Blood Cells | \n", "Pus Cell | \n", "Pus Cell clumps | \n", "Bacteria | \n", "Glucose | \n", "... | \n", "Packed Cell Volume | \n", "White Blood Cell Count | \n", "Red Blood Cell Count | \n", "Hypertension | \n", "Diabetes Mellitus | \n", "Coronary Artery Disease | \n", "Appetite | \n", "Pedal Edema | \n", "Anemia | \n", "Class | \n", "
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | \n", "48 | \n", "70 | \n", "1.005 | \n", "4 | \n", "0 | \n", "normal | \n", "abnormal | \n", "present | \n", "notpresent | \n", "117 | \n", "... | \n", "32 | \n", "6700 | \n", "3.9 | \n", "yes | \n", "no | \n", "no | \n", "poor | \n", "yes | \n", "yes | \n", "1 | \n", "
1 | \n", "53 | \n", "90 | \n", "1.020 | \n", "2 | \n", "0 | \n", "abnormal | \n", "abnormal | \n", "present | \n", "notpresent | \n", "70 | \n", "... | \n", "29 | \n", "12100 | \n", "3.7 | \n", "yes | \n", "yes | \n", "no | \n", "poor | \n", "no | \n", "yes | \n", "1 | \n", "
2 | \n", "63 | \n", "70 | \n", "1.010 | \n", "3 | \n", "0 | \n", "abnormal | \n", "abnormal | \n", "present | \n", "notpresent | \n", "380 | \n", "... | \n", "32 | \n", "4500 | \n", "3.8 | \n", "yes | \n", "yes | \n", "no | \n", "poor | \n", "yes | \n", "no | \n", "1 | \n", "
3 | \n", "68 | \n", "80 | \n", "1.010 | \n", "3 | \n", "2 | \n", "normal | \n", "abnormal | \n", "present | \n", "present | \n", "157 | \n", "... | \n", "16 | \n", "11000 | \n", "2.6 | \n", "yes | \n", "yes | \n", "yes | \n", "poor | \n", "yes | \n", "no | \n", "1 | \n", "
4 | \n", "61 | \n", "80 | \n", "1.015 | \n", "2 | \n", "0 | \n", "abnormal | \n", "abnormal | \n", "notpresent | \n", "notpresent | \n", "173 | \n", "... | \n", "24 | \n", "9200 | \n", "3.2 | \n", "yes | \n", "yes | \n", "yes | \n", "poor | \n", "yes | \n", "yes | \n", "1 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
153 | \n", "55 | \n", "80 | \n", "1.020 | \n", "0 | \n", "0 | \n", "normal | \n", "normal | \n", "notpresent | \n", "notpresent | \n", "140 | \n", "... | \n", "47 | \n", "6700 | \n", "4.9 | \n", "no | \n", "no | \n", "no | \n", "good | \n", "no | \n", "no | \n", "0 | \n", "
154 | \n", "42 | \n", "70 | \n", "1.025 | \n", "0 | \n", "0 | \n", "normal | \n", "normal | \n", "notpresent | \n", "notpresent | \n", "75 | \n", "... | \n", "54 | \n", "7800 | \n", "6.2 | \n", "no | \n", "no | \n", "no | \n", "good | \n", "no | \n", "no | \n", "0 | \n", "
155 | \n", "12 | \n", "80 | \n", "1.020 | \n", "0 | \n", "0 | \n", "normal | \n", "normal | \n", "notpresent | \n", "notpresent | \n", "100 | \n", "... | \n", "49 | \n", "6600 | \n", "5.4 | \n", "no | \n", "no | \n", "no | \n", "good | \n", "no | \n", "no | \n", "0 | \n", "
156 | \n", "17 | \n", "60 | \n", "1.025 | \n", "0 | \n", "0 | \n", "normal | \n", "normal | \n", "notpresent | \n", "notpresent | \n", "114 | \n", "... | \n", "51 | \n", "7200 | \n", "5.9 | \n", "no | \n", "no | \n", "no | \n", "good | \n", "no | \n", "no | \n", "0 | \n", "
157 | \n", "58 | \n", "80 | \n", "1.025 | \n", "0 | \n", "0 | \n", "normal | \n", "normal | \n", "notpresent | \n", "notpresent | \n", "131 | \n", "... | \n", "53 | \n", "6800 | \n", "6.1 | \n", "no | \n", "no | \n", "no | \n", "good | \n", "no | \n", "no | \n", "0 | \n", "
158 rows × 25 columns
\n", "\n", " | Hemoglobin | \n", "Glucose | \n", "White Blood Cell Count | \n", "Class | \n", "
---|---|---|---|---|
0 | \n", "-0.865744 | \n", "-0.221549 | \n", "-0.569768 | \n", "1 | \n", "
1 | \n", "-1.457446 | \n", "-0.947597 | \n", "1.162684 | \n", "1 | \n", "
2 | \n", "-1.004968 | \n", "3.841231 | \n", "-1.275582 | \n", "1 | \n", "
3 | \n", "-2.814879 | \n", "0.396364 | \n", "0.809777 | \n", "1 | \n", "
4 | \n", "-2.083954 | \n", "0.643529 | \n", "0.232293 | \n", "1 | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
153 | \n", "0.700526 | \n", "0.133751 | \n", "-0.569768 | \n", "0 | \n", "
154 | \n", "0.978974 | \n", "-0.870358 | \n", "-0.216861 | \n", "0 | \n", "
155 | \n", "0.735332 | \n", "-0.484162 | \n", "-0.601850 | \n", "0 | \n", "
156 | \n", "0.178436 | \n", "-0.267893 | \n", "-0.409356 | \n", "0 | \n", "
157 | \n", "0.735332 | \n", "-0.005280 | \n", "-0.537686 | \n", "0 | \n", "
158 rows × 4 columns
\n", "\n", " | Hemoglobin | \n", "Glucose | \n", "White Blood Cell Count | \n", "Class | \n", "Color | \n", "
---|---|---|---|---|---|
0 | \n", "-0.865744 | \n", "-0.221549 | \n", "-0.569768 | \n", "1 | \n", "darkblue | \n", "
1 | \n", "-1.457446 | \n", "-0.947597 | \n", "1.162684 | \n", "1 | \n", "darkblue | \n", "
2 | \n", "-1.004968 | \n", "3.841231 | \n", "-1.275582 | \n", "1 | \n", "darkblue | \n", "
3 | \n", "-2.814879 | \n", "0.396364 | \n", "0.809777 | \n", "1 | \n", "darkblue | \n", "
4 | \n", "-2.083954 | \n", "0.643529 | \n", "0.232293 | \n", "1 | \n", "darkblue | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
153 | \n", "0.700526 | \n", "0.133751 | \n", "-0.569768 | \n", "0 | \n", "gold | \n", "
154 | \n", "0.978974 | \n", "-0.870358 | \n", "-0.216861 | \n", "0 | \n", "gold | \n", "
155 | \n", "0.735332 | \n", "-0.484162 | \n", "-0.601850 | \n", "0 | \n", "gold | \n", "
156 | \n", "0.178436 | \n", "-0.267893 | \n", "-0.409356 | \n", "0 | \n", "gold | \n", "
157 | \n", "0.735332 | \n", "-0.005280 | \n", "-0.537686 | \n", "0 | \n", "gold | \n", "
158 rows × 5 columns
\n", "\n", " | Hemoglobin | \n", "Glucose | \n", "
---|---|---|
0 | \n", "-2.0 | \n", "-2.0 | \n", "
1 | \n", "-2.0 | \n", "-1.9 | \n", "
2 | \n", "-2.0 | \n", "-1.8 | \n", "
3 | \n", "-2.0 | \n", "-1.7 | \n", "
4 | \n", "-2.0 | \n", "-1.6 | \n", "
... | \n", "... | \n", "... | \n", "
1676 | \n", "2.0 | \n", "1.6 | \n", "
1677 | \n", "2.0 | \n", "1.7 | \n", "
1678 | \n", "2.0 | \n", "1.8 | \n", "
1679 | \n", "2.0 | \n", "1.9 | \n", "
1680 | \n", "2.0 | \n", "2.0 | \n", "
1681 rows × 2 columns
\n", "\n", " | Hemoglobin | \n", "Glucose | \n", "Class | \n", "
---|---|---|---|
0 | \n", "-2.0 | \n", "-2.0 | \n", "1 | \n", "
1 | \n", "-2.0 | \n", "-1.9 | \n", "1 | \n", "
2 | \n", "-2.0 | \n", "-1.8 | \n", "1 | \n", "
3 | \n", "-2.0 | \n", "-1.7 | \n", "1 | \n", "
4 | \n", "-2.0 | \n", "-1.6 | \n", "1 | \n", "
\n", " | Hemoglobin | \n", "Glucose | \n", "Class | \n", "Color | \n", "
---|---|---|---|---|
0 | \n", "-2.0 | \n", "-2.0 | \n", "1 | \n", "darkblue | \n", "
1 | \n", "-2.0 | \n", "-1.9 | \n", "1 | \n", "darkblue | \n", "
2 | \n", "-2.0 | \n", "-1.8 | \n", "1 | \n", "darkblue | \n", "
3 | \n", "-2.0 | \n", "-1.7 | \n", "1 | \n", "darkblue | \n", "
4 | \n", "-2.0 | \n", "-1.6 | \n", "1 | \n", "darkblue | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
1676 | \n", "2.0 | \n", "0.6 | \n", "0 | \n", "gold | \n", "
1677 | \n", "2.0 | \n", "0.7 | \n", "0 | \n", "gold | \n", "
1678 | \n", "2.0 | \n", "0.8 | \n", "0 | \n", "gold | \n", "
1679 | \n", "2.0 | \n", "0.9 | \n", "0 | \n", "gold | \n", "
1680 | \n", "2.0 | \n", "1.0 | \n", "0 | \n", "gold | \n", "
1681 rows × 4 columns
\n", "