{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "tags": [
     "remove_input"
    ]
   },
   "outputs": [],
   "source": [
    "path_data = '../../../../data/'\n",
    "\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import math\n",
    "import scipy.stats as stats\n",
    "\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "from mpl_toolkits.mplot3d import Axes3D\n",
    "plt.style.use('fivethirtyeight')\n",
    "\n",
    "import warnings\n",
    "warnings.filterwarnings('ignore')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": [
     "remove_input"
    ]
   },
   "outputs": [],
   "source": [
    "def standard_units(x):\n",
    "    \"\"\"Convert x to standard units: subtract its mean, then divide by its SD (np.std).\"\"\"\n",
    "    return (x - np.mean(x))/np.std(x)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def distance(point1, point2):\n",
    "    \"\"\"The Euclidean distance between two arrays of numbers.\"\"\"\n",
    "    return np.sqrt(np.sum((point1 - point2)**2))\n",
    "\n",
    "def all_distances(training, point):\n",
    "    \"\"\"The distance from point (an array of numbers) to every row of training.\n",
    "\n",
    "    The 'Class' column is dropped first, so only the attribute columns are\n",
    "    compared. Returns a Series that shares training's index.\n",
    "    \"\"\"\n",
    "    attributes = training.drop(columns=['Class'])\n",
    "    # Flatten so that a 1-D point and a 1 x d row vector both broadcast\n",
    "    # against the (n, d) attribute matrix.\n",
    "    target = np.asarray(point, dtype=float).ravel()\n",
    "    # One vectorized pass over all rows instead of a Python call per row.\n",
    "    diffs = attributes.to_numpy(dtype=float) - target\n",
    "    return pd.Series(np.sqrt(np.sum(diffs**2, axis=1)), index=attributes.index)\n",
    "\n",
    "def table_with_distances(training, point):\n",
    "    \"\"\"A copy of the training table with a 'Distance' column: the distance from each row to point.\"\"\"\n",
    "    training1 = training.copy()\n",
    "    training1['Distance'] = all_distances(training1, point)\n",
    "    return training1\n",
    "\n",
    "def closest(training, point, k):\n",
    "    \"\"\"A table containing the k closest rows in the training table to point.\n",
    "\n",
    "    If k exceeds the number of training rows, every row is returned;\n",
    "    k <= 0 yields an empty table.\n",
    "    \"\"\"\n",
    "    with_dists = table_with_distances(training, point)\n",
    "    sorted_by_distance = with_dists.sort_values(by=['Distance'])\n",
    "    # head() clamps to the table length, whereas take(np.arange(k))\n",
    "    # raised IndexError when k was larger than the table.\n",
    "    return sorted_by_distance.head(max(k, 0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def classify_grid(training, test, k):\n",
    "    \"\"\"Classify each row of test by majority vote among its k nearest training rows.\n",
    "\n",
    "    Returns a float array with one entry per test row: 1 if more of the k\n",
    "    nearest neighbors have Class 1 than Class 0, otherwise 0 (ties go to 0).\n",
    "    \"\"\"\n",
    "    predictions = []\n",
    "    for i in range(len(test)):\n",
    "        neighbors = closest(training, test.iloc[i].to_numpy(), k)\n",
    "        # Compare as strings so labels stored as '1'/'0' or as 1/0 both count.\n",
    "        votes = neighbors['Class'].astype(str)\n",
    "        ones = (votes == '1').sum()\n",
    "        zeros = (votes == '0').sum()\n",
    "        predictions.append(1 if ones > zeros else 0)\n",
    "    # Build a list and convert once; np.append in a loop recopies the array each time.\n",
    "    return np.array(predictions, dtype=float)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [ "
\n", " | Hemoglobin | \n", "Glucose | \n", "White Blood Cell Count | \n", "Class | \n", "Color | \n", "
---|---|---|---|---|---|
0 | \n", "-0.865744 | \n", "-0.221549 | \n", "-0.569768 | \n", "1 | \n", "darkblue | \n", "
1 | \n", "-1.457446 | \n", "-0.947597 | \n", "1.162684 | \n", "1 | \n", "darkblue | \n", "
2 | \n", "-1.004968 | \n", "3.841231 | \n", "-1.275582 | \n", "1 | \n", "darkblue | \n", "
3 | \n", "-2.814879 | \n", "0.396364 | \n", "0.809777 | \n", "1 | \n", "darkblue | \n", "
4 | \n", "-2.083954 | \n", "0.643529 | \n", "0.232293 | \n", "1 | \n", "darkblue | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
153 | \n", "0.700526 | \n", "0.133751 | \n", "-0.569768 | \n", "0 | \n", "gold | \n", "
154 | \n", "0.978974 | \n", "-0.870358 | \n", "-0.216861 | \n", "0 | \n", "gold | \n", "
155 | \n", "0.735332 | \n", "-0.484162 | \n", "-0.601850 | \n", "0 | \n", "gold | \n", "
156 | \n", "0.178436 | \n", "-0.267893 | \n", "-0.409356 | \n", "0 | \n", "gold | \n", "
157 | \n", "0.735332 | \n", "-0.005280 | \n", "-0.537686 | \n", "0 | \n", "gold | \n", "
158 rows × 5 columns
\n", "\n", " | Glucose | \n", "White Blood Cell Count | \n", "Class | \n", "
---|---|---|---|
0 | \n", "-2.0 | \n", "-2.00 | \n", "0 | \n", "
1 | \n", "-2.0 | \n", "-1.75 | \n", "0 | \n", "
2 | \n", "-2.0 | \n", "-1.50 | \n", "0 | \n", "
3 | \n", "-2.0 | \n", "-1.25 | \n", "0 | \n", "
4 | \n", "-2.0 | \n", "-1.00 | \n", "0 | \n", "
\n", " | Glucose | \n", "White Blood Cell Count | \n", "Class | \n", "Color | \n", "
---|---|---|---|---|
0 | \n", "-2.0 | \n", "-2.00 | \n", "0 | \n", "gold | \n", "
1 | \n", "-2.0 | \n", "-1.75 | \n", "0 | \n", "gold | \n", "
2 | \n", "-2.0 | \n", "-1.50 | \n", "0 | \n", "gold | \n", "
3 | \n", "-2.0 | \n", "-1.25 | \n", "0 | \n", "gold | \n", "
4 | \n", "-2.0 | \n", "-1.00 | \n", "0 | \n", "gold | \n", "
... | \n", "... | \n", "... | \n", "... | \n", "... | \n", "
1084 | \n", "6.0 | \n", "5.00 | \n", "1 | \n", "darkblue | \n", "
1085 | \n", "6.0 | \n", "5.25 | \n", "1 | \n", "darkblue | \n", "
1086 | \n", "6.0 | \n", "5.50 | \n", "1 | \n", "darkblue | \n", "
1087 | \n", "6.0 | \n", "5.75 | \n", "1 | \n", "darkblue | \n", "
1088 | \n", "6.0 | \n", "6.00 | \n", "1 | \n", "darkblue | \n", "
1089 rows × 4 columns
\n", "