{ "cells": [ { "cell_type": "code", "execution_count": 16, "metadata": { "collapsed": true }, "outputs": [], "source": [ "import numpy, scipy, matplotlib.pyplot as plt, sklearn, stanford_mir\n", "import sklearn.neighbors, sklearn.model_selection, sklearn.linear_model\n", "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[← Back to Index](index.html)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Cross Validation" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "<a href=\"https://en.wikipedia.org/wiki/Cross-validation_(statistics)\">K-fold cross validation</a> is a method for estimating how well a classifier generalizes to data it was not trained on.\n", "\n", "For example, with 10-fold cross validation:\n", "\n", "1. Randomly divide the data set into 10 partitions.\n", "2. Choose one of the partitions as the test set. Train on the other nine partitions.\n", "3. Repeat until each partition has served as the test set once, then average the 10 scores.\n", "\n", "Why is cross validation useful?\n", "* In K-fold cross validation, the model is evaluated K times, each time on a different partition of the data, so the score does not depend on a single train/test split.\n", "* It can be used to tune parameters and to choose the best model and/or features." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Setup Data" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Load some features from ten kick drums and ten snare drums:\n", "* `training_features` is a 20-by-2 array: each row holds the zero crossing rate and spectral centroid of one drum sample." ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [], "source": [ "training_features, training_labels, scaler = stanford_mir.get_features()" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[[-0.56578947 0.2490469 ]\n", " [-0.92105263 -0.63084112]\n", " [-1. -1. ]\n", " [-1. 
0.32776395]\n", " [-0.77631579 -0.55766302]\n", " [-0.96052632 -0.69890631]\n", " [-0.68421053 -0.6353643 ]\n", " [-0.73684211 -0.10209073]\n", " [-0.82894737 -0.38692759]\n", " [-0.69736842 -0.5365142 ]\n", " [ 0.53947368 0.89837486]\n", " [ 0.63157895 0.55248334]\n", " [ 1. 0.99378504]\n", " [ 0.93421053 1. ]\n", " [ 0.85526316 0.87391641]\n", " [ 0.89473684 0.96797404]\n", " [ 0.23684211 0.86119147]\n", " [ 0.97368421 0.79526304]\n", " [ 0.40789474 0.38793404]\n", " [ 0.88157895 0.88600432]]\n", "[ 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 1. 1. 1. 1. 1. 1. 1. 1.\n", " 1. 1.]\n" ] } ], "source": [ "print(training_features)\n", "print(training_labels)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot the features:" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<matplotlib.collections.PathCollection at 0x1c121d9b50>" ] }, "execution_count": 19, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAYYAAAD8CAYAAABzTgP2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMS4wLCBo\ndHRwOi8vbWF0cGxvdGxpYi5vcmcvpW3flQAAGu9JREFUeJzt3X+MXeV95/H3J8Y2U1XN2HiS2AOO\njdbrhIqVnb2l2bXUJAQYJ39gl9LESFFMCrKSLV1po1ixRaWuaCOc+g92q7KbuIRA0gpIqONMRaKp\nwdBIu4H6WiYMBg0ezGaZGS92A5Nql6mxzXf/uGfIOcO9c+/MOffXzOclXd1znvOcc75+5vp+73nO\nc85RRGBmZjbtPe0OwMzMOosTg5mZZTgxmJlZhhODmZllODGYmVmGE4OZmWU4MZiZWYYTg5mZZTgx\nmJlZxiXtDmA+Vq1aFevWrWt3GGZmXeXYsWP/FBF99ep1ZWJYt24d5XK53WGYmXUVST9vpJ67kszM\nLMOJwczMMpwYzMwsw4nBzMwynBjMzCzDicHMzDIKSQyS7pd0RtLzNZZL0l9IGpX0nKSPpJbtlHQy\nee0sIh4zM5u/oq5jeAD4S+A7NZZ/CtiQvH4b+O/Ab0taCfwJUAICOCZpMCLeKCguM7OOdOj4OPuH\nRpiYnGJNbw+7BzayfXP/vOsVqZAjhoj4CfD6LFW2Ad+JiqeBXkmrgQHgcES8niSDw8DWImIyM+tU\nh46Ps/fgMOOTUwQwPjnF3oPDHDo+Pq96RWvVOYZ+4NXU/FhSVqvczGzB2j80wtT5i5myqfMX2T80\nMq96RWvVLTFUpSxmKX/3BqRdwC6AtWvXFheZmVkTVesKmpicqlp3Znmj9YrWqiOGMeCK1PzlwMQs\n5e8SEQciohQRpb6+uveAMjNru1pdQe/tWVq1/prenlnn65UXpVWJYRD4fDI66aPALyPiNDAE3CBp\nhaQVwA1JmZlZ16vVFSRBz9IlmfKepUvYPbAxU7Z7YGND9YpWSFeSpIeAjwOrJI1RGWm0FCAivgH8\nCPg0MAq8CXwhWfa6pD8FjiabuisiZjuJbWbWNWp1+Uy+eZ57Prup7mij6flWj0pSRNUu/Y5WKpXC\nt902s063Zd8Rxqskh/7eHv7HnmtbHo+kYxFRqlfPVz6bmTVJu7qC8urKB/WY2cLUjou5mqldXUF5\nOTGYWUeYHsEzfbJ2egQP0PFfpLPZvrm/6+J3YjCzjjDbxVzd9sUK3X3048RgZh2hXRdzFSWdCN7b\ns5T/99YFzl+sDO7ptqMfn3w2W4AOHR9ny74jrN/zGFv2HWn6vXWK0K6LuYow80K2yanz7ySFaa24\nlUVRnBjMFph23Xgtr24dwQPVu8Gq6ZajHycGswWmXTdey2v75n7uvulq+nt7EJWx/nffdHVXdL00\n+oXfDUc/4HMMZgtON/fVd+MIHqh84Ve7kC2tW45+wEcMZgtON/fVd6tq3WBL3yNW/NrSrjv6AR8x\nmC04uwc2Zq4HgO76tdqNuvVCtlqcGMwWmIX2JdUturUbrBonBrMFaCF9SS02nXBhnBODmVmH6JTb\ngvjks5lZh+iUocaFJAZJWyWNSBqVtKfK8nskPZu8XpI0mVp2MbVssIh4zMy6UacMNc7dlSRpCXAv\ncD2VZzgflTQYES9M14mI/5Sq/0fA5tQmpiJiU944zMy6Xa3rIVo91LiII4ZrgNGIOBURbwEPA9tm\nqX8L8FAB+zUzW1A65bYgRSSGfuDV1PxYUvYukj4IrAeOpIovlVSW9LSk7QXEY2bWlTrltiBFjEpS\nlbJaD5LeATwaEemzK2sjYkLSlcARScMR8fK7diLtAnYBrF27Nm/MZmYdqROGGhdxxDAGXJGavxyY\nqFF3BzO6kSJiInk/BTxF9vxDut6BiChFRKmvry9vzGZmVkMRi
eEosEHSeknLqHz5v2t0kaSNwArg\np6myFZKWJ9OrgC3ACzPXNTOz1sndlRQRFyTdAQwBS4D7I+KEpLuAckRMJ4lbgIcjIt3N9GHgm5Le\nppKk9qVHM5mZWesp+z3dHUqlUpTL5XaHYWbWVSQdi4hSvXq+8tnMzDKcGMzMLMOJwczMMpwYzMws\nw4nBzMwynBjMzCzDicHMzDKcGMzMLMOJwczMMpwYzMwsw4nBzMwyingeg5lZ2xw6Ps7+oREmJqdY\n09vD7oGNbX+eQbdzYjCzrnXo+Dh7Dw4zdb7y7K/xySn2HhwGcHLIwV1JZta19g+NvJMUpk2dv8j+\noZE2RbQwODGYWdeamJyaU7k1xonBzLrWmt6eOZVbYwpJDJK2ShqRNCppT5Xlt0o6K+nZ5HV7atlO\nSSeT184i4jGzxWH3wEZ6li7JlPUsXcLugY1timhhyH3yWdIS4F7gemAMOCppsMojOh+JiDtmrLsS\n+BOgBARwLFn3jbxxmdnCN32C2aOSilXEqKRrgNGIOAUg6WFgG9DIs5sHgMMR8Xqy7mFgK/BQAXGZ\n2SKwfXO/E0HBiuhK6gdeTc2PJWUz/Z6k5yQ9KumKOa6LpF2SypLKZ8+eLSBsMzOrpojEoCplMWP+\n74B1EfFvgMeBB+ewbqUw4kBElCKi1NfXN+9gzcxsdkUkhjHgitT85cBEukJE/CIiziWzfwX820bX\nNTOz1ioiMRwFNkhaL2kZsAMYTFeQtDo1eyPwYjI9BNwgaYWkFcANSZmZmbVJ7pPPEXFB0h1UvtCX\nAPdHxAlJdwHliBgE/qOkG4ELwOvArcm6r0v6UyrJBeCu6RPRZmbWHoqo2qXf0UqlUpTL5XaHYWbW\nVSQdi4hSvXq+8tnMzDKcGMzMLMO33TazefOzEBYmJwYzmxc/C2HhcleSmc2Ln4WwcDkxmNm8+FkI\nC5cTg5nNi5+FsHA5MZjZvPhZCAvXojr57BEUZsXxsxAWrkWTGDyCwqx4fhbCwrRoupI8gsLMrDGL\nJjF4BIWZWWMWTWLwCAozs8YsmsTgERRmZo1ZNCefPYLCzKwxhSQGSVuB/0rlQT33RcS+Gcu/DNxO\n5UE9Z4E/iIifJ8suAsNJ1f8dETcWEVM1HkFhZlZf7sQgaQlwL3A9lWc4H5U0GBEvpKodB0oR8aak\nLwF/Dnw2WTYVEZvyxmGN8/UcZjabIs4xXAOMRsSpiHgLeBjYlq4QEU9GxJvJ7NPA5QXs1+Zh+nqO\n8ckpgl9dz3Ho+Hi7QzOzDlFEYugHXk3NjyVltdwG/Dg1f6mksqSnJW0vIB6bha/nMLN6ijjHoCpl\nVR8kLelzQAn4WKp4bURMSLoSOCJpOCJerrLuLmAXwNq1a/NHvUj5eg4zq6eII4Yx4IrU/OXAxMxK\nkq4D7gRujIhz0+URMZG8nwKeAjZX20lEHIiIUkSU+vr6Cgh7cfL1HGZWTxGJ4SiwQdJ6ScuAHcBg\nuoKkzcA3qSSFM6nyFZKWJ9OrgC1A+qS1FczXc5hZPbm7kiLigqQ7gCEqw1Xvj4gTku4CyhExCOwH\nfh34viT41bDUDwPflPQ2lSS1b8ZoJiuYr+cws3oUUfV0QEcrlUpRLpfbHYaZWVeRdCwiSvXqLZpb\nYpiZWWOcGMzMLMOJwczMMpwYzMwsw4nBzMwynBjMzCzDicHMzDKcGMzMLMOJwczMMpwYzMwsw4nB\nzMwynBjMzCzDicHMzDKcGMzMLMOJwczMMgpJDJK2ShqRNCppT5XlyyU9kix/RtK61LK9SfmIpIEi\n4jEzs/nLnRgkLQHuBT4FXAXcIumqGdVuA96IiH8F3AN8PVn3KiqPAv1NYCvw35LtmZlZmxRxxHAN\nMBoRpyLiLeBhYNuMOtuAB5PpR4FPqvKMz23AwxFxLiJeAUaT7ZmZWZsUkRj6gVdT82NJWdU6EXEB\n+CVwWYPrmplZCxWRGFSlb
OaDpGvVaWTdygakXZLKkspnz56dY4hmZtaoIhLDGHBFav5yYKJWHUmX\nAO8FXm9wXQAi4kBElCKi1NfXV0DYZmZWTRGJ4SiwQdJ6ScuonEwenFFnENiZTN8MHImISMp3JKOW\n1gMbgH8sICYzM5unS/JuICIuSLoDGAKWAPdHxAlJdwHliBgEvgV8V9IolSOFHcm6JyR9D3gBuAD8\nYURczBuTmZnNnyo/3LtLqVSKcrnc7jDMzLqKpGMRUapXz1c+m5lZhhODmZll5D7HYMU5dHyc/UMj\nTExOsaa3h90DG9m+2Zd1mFlrOTF0iEPHx9l7cJip85Vz7+OTU+w9OAzg5GBmLeWupA6xf2jknaQw\nber8RfYPjbQpIjNbrJwYOsTE5NScys3MmsWJoUOs6e2ZU7mZWbM4MXSI3QMb6VmaveN4z9Il7B7Y\n2KaIzGyx8snnDjF9gtmjksys3ZwYOsj2zf1OBGbWdu5KMjOzDCcGMzPLcGIwM7MMJwYzM8twYjAz\nswwnBjMzy8iVGCStlHRY0snkfUWVOpsk/VTSCUnPSfpsatkDkl6R9Gzy2pQnHjMzyy/vEcMe4ImI\n2AA8kczP9Cbw+Yj4TWAr8F8k9aaW746ITcnr2ZzxmJlZTnkTwzbgwWT6QWD7zAoR8VJEnEymJ4Az\nQF/O/ZqZWZPkTQzvj4jTAMn7+2arLOkaYBnwcqr4a0kX0z2Sls+y7i5JZUnls2fP5gzbzMxqqZsY\nJD0u6fkqr21z2ZGk1cB3gS9ExNtJ8V7gQ8BvASuBr9ZaPyIOREQpIkp9fT7gMDNrlrr3SoqI62ot\nk/SapNURcTr54j9To95vAI8BfxwRT6e2fTqZPCfp28BX5hS9mZkVLm9X0iCwM5neCfxwZgVJy4Af\nAN+JiO/PWLY6eReV8xPP54zHzMxyypsY9gHXSzoJXJ/MI6kk6b6kzmeA3wFurTIs9W8kDQPDwCrg\nz3LGY2ZmOSki2h3DnJVKpSiXy+0Ow8ysq0g6FhGlevV85bOZmWU4MZiZWYYTg5mZZTgxmJlZhhOD\nmZll1L3Azebu0PFx9g+NMDE5xZreHnYPbGT75v52h2Vm1hAnhoIdOj7O3oPDTJ2/CMD45BR7Dw4D\nODmYWVdwV1LB9g+NvJMUpk2dv8j+oZE2RWRmNjdODAWbmJyaU7mZWadxYijYmt6eOZWbmXUaJ4aC\n7R7YSM/SJZmynqVL2D2wsU0RmZnNjU8+F2z6BLNHJZlZt3JiaILtm/udCMysa7kryczMMpwYzMws\nI1dikLRS0mFJJ5P3FTXqXUw9pGcwVb5e0jPJ+o8kT3szM7M2ynvEsAd4IiI2AE8k89VMRcSm5HVj\nqvzrwD3J+m8At+WMx8zMcsqbGLYBDybTD1J5bnNDkuc8Xws8Op/1zcysOfImhvdHxGmA5P19Nepd\nKqks6WlJ01/+lwGTEXEhmR8DPJTHzKzN6g5XlfQ48IEqi+6cw37WRsSEpCuBI5KGgX+uUq/mA6gl\n7QJ2Aaxdu3YOu+5+vlurmbVS3cQQEdfVWibpNUmrI+K0pNXAmRrbmEjeT0l6CtgM/C3QK+mS5Kjh\ncmBiljgOAAcASqVSzQSy0PhurWbWanm7kgaBncn0TuCHMytIWiFpeTK9CtgCvBARATwJ3Dzb+gvR\noePjbNl3hPV7HmPLviMcOj5es67v1mpmrZY3MewDrpd0Erg+mUdSSdJ9SZ0PA2VJP6OSCPZFxAvJ\nsq8CX5Y0SuWcw7dyxtPxpo8AxienCH51BFArOfhurWbWarluiRERvwA+WaW8DNyeTP9P4Ooa658C\nrskTQ7eZ7QigWtfQmt4exqskAd+t1cyaxVc+t9hcjwB8t1YzazUnhhab6/Matm/u5+6brqa/twcB\n/b093H3T1T7xbGZN47urttjugY2ZUUZQ/wjAd2s1s1ZyYmgxP6/BzDqdE0Mb+AjAzDqZE8M
sOu2K\n406Lx8wWJieGGjrtiuNOi8fMFi6PSqqh06447rR4zGzhcmKoodOuOO60eMxs4XJiqGGu1xs0W6fF\nY2YLlxNDDZ12xXGnxWNmC5dPPjP7aJ9OGQXUafGY2cKlyt2vu0upVIpyuVzItmaO9oHKL3HfdsLM\nFhpJxyKiVK/eou9K8mgfM7OsRZ8YPNrHzCxr0ScGj/YxM8vKlRgkrZR0WNLJ5H1FlTqfkPRs6vUv\nkrYnyx6Q9Epq2aY88cyHR/uYmWXlPWLYAzwRERuAJ5L5jIh4MiI2RcQm4FrgTeDvU1V2Ty+PiGdz\nxjNnft6BmVlW3uGq24CPJ9MPAk9ReY5zLTcDP46IN3Put1C+26mZ2a/kPWJ4f0ScBkje31en/g7g\noRllX5P0nKR7JC2vtaKkXZLKkspnz57NF7WZmdVUNzFIelzS81Ve2+ayI0mrgauBoVTxXuBDwG8B\nK5nlaCMiDkREKSJKfX19c9m1mZnNQd2upIi4rtYySa9JWh0Rp5Mv/jOzbOozwA8i4nxq26eTyXOS\nvg18pcG4zcysSfJ2JQ0CO5PpncAPZ6l7CzO6kZJkgiQB24Hnc8ZjZmY55U0M+4DrJZ0Erk/mkVSS\ndN90JUnrgCuAf5ix/t9IGgaGgVXAn+WMx8zMcso1KikifgF8skp5Gbg9Nf+/gHcN+4mIa/Ps38zM\nirfor3w2M7MsJwYzM8twYjAzswwnBjMzy3BiMDOzDCcGMzPLcGIwM7MMJwYzM8twYjAzswwnBjMz\ny3BiMDOzDCcGMzPLcGIwM7MMJwYzM8twYjAzs4xciUHS70s6IeltSaVZ6m2VNCJpVNKeVPl6Sc9I\nOinpEUnL8sRjZmb55T1ieB64CfhJrQqSlgD3Ap8CrgJukXRVsvjrwD0RsQF4A7gtZzyzOnR8nC37\njrB+z2Ns2XeEQ8fHm7k7M7OulCsxRMSLETFSp9o1wGhEnIqIt4CHgW3Jc56vBR5N6j1I5bnPTXHo\n+Dh7Dw4zPjlFAOOTU+w9OOzkYGY2QyvOMfQDr6bmx5Kyy4DJiLgwo7wp9g+NMHX+YqZs6vxF9g/V\ny2tmZotL3Wc+S3oc+ECVRXdGxA8b2IeqlMUs5bXi2AXsAli7dm0Du82amJyaU7mZ2WJVNzFExHU5\n9zEGXJGavxyYAP4J6JV0SXLUMF1eK44DwAGAUqlUM4HUsqa3h/EqSWBNb89cN2VmtqC1oivpKLAh\nGYG0DNgBDEZEAE8CNyf1dgKNHIHMy+6BjfQsXZIp61m6hN0DG5u1SzOzrpR3uOrvShoD/h3wmKSh\npHyNpB8BJEcDdwBDwIvA9yLiRLKJrwJfljRK5ZzDt/LEM5vtm/u5+6ar6e/tQUB/bw9333Q12zc3\n7bSGmVlXUuWHe3cplUpRLpfbHYaZWVeRdCwial5zNs1XPpuZWYYTg5mZZTgxmJlZhhODmZllODGY\nmVmGE4OZmWV05XBVSWeBn+fYxCoqV153Gsc1N50YVyfGBI5rLjoxJigmrg9GRF+9Sl2ZGPKSVG5k\nLG+rOa656cS4OjEmcFxz0YkxQWvjcleSmZllODGYmVnGYk0MB9odQA2Oa246Ma5OjAkc11x0YkzQ\nwrgW5TkGMzOrbbEeMZiZWQ0LNjFI+n1JJyS9LanmmXxJWyWNSBqVtCdVvl7SM5JOSnokeZZEEXGt\nlHQ42e5hSSuq1PmEpGdTr3+RtD1Z9oCkV1LLNrUqrqTexdS+B1PlhbdXg221SdJPk7/1c5I+m1pW\naFvV+qykli9P/u2jSVusSy3bm5SPSBrIE8ccY/qypBeStnlC0gdTy6r+LVsU162Szqb2f3tq2c7k\nb35S0s4Wx3VPKqaXJE2mljWlvSTdL+mMpOdrLJekv0hifk7SR1LLmtNWEbEgX8CHgY3AU0CpRp0l\nwMvAlcAy4GfAVcmy7wE7kulvAF8qKK4/B/Yk03uAr9e
pvxJ4Hfi1ZP4B4OYmtFdDcQH/t0Z54e3V\nSEzAvwY2JNNrgNNAb9FtNdtnJVXnPwDfSKZ3AI8k01cl9ZcD65PtLGlRTJ9IfXa+NB3TbH/LFsV1\nK/CXNT7vp5L3Fcn0ilbFNaP+HwH3t6C9fgf4CPB8jeWfBn5M5XHIHwWeaXZbLdgjhoh4MSJG6lS7\nBhiNiFMR8RbwMLBNkoBrgUeTeg8C2wsKbVuyvUa3ezPw44h4s6D91zLXuN7RxPaqG1NEvBQRJ5Pp\nCeAMUPcCnnmo+lmZJd5HgU8mbbMNeDgizkXEK8Bosr2mxxQRT6Y+O09TeYRuszXSVrUMAIcj4vWI\neAM4DGxtU1y3AA8VtO+aIuInVH781bIN+E5UPE3lkciraWJbLdjE0KB+4NXU/FhSdhkwGZWnz6XL\ni/D+iDgNkLy/r079Hbz7w/m15JDyHknLWxzXpZLKkp6e7t6iee01p7aSdA2VX4Ivp4qLaqtan5Wq\ndZK2+CWVtmlk3WbFlHYblV+e06r9LYvQaFy/l/xtHpU0/Vz4ZrXVnLaddLmtB46kipvVXvXUirtp\nbXVJERtpF0mPAx+osujOiGjk+dGqUhazlOeOq9FtJNtZDVxN5bGo0/YC/4fKF+ABKo9HvauFca2N\niAlJVwJHJA0D/1ylXkPtVXBbfRfYGRFvJ8Xzbqtqu6hSNvPf2JTP0ywa3q6kzwEl4GOp4nf9LSPi\n5WrrNyGuvwMeiohzkr5I5Ujr2gbXbWZc03YAj0bExVRZs9qrnlZ/rro7MUTEdTk3MQZckZq/HJig\ncj+SXkmXJL/8pstzxyXpNUmrI+J08mV2ZpZNfQb4QUScT237dDJ5TtK3ga+0Mq6ku4aIOCXpKWAz\n8LfMs72KiEnSbwCPAX+cHGpPb3vebVVFrc9KtTpjki4B3kuli6CRdZsVE5Kuo5JoPxYR56bLa/wt\ni/iiqxtXRPwiNftXwNdT6358xrpPFRBTQ3Gl7AD+MF3QxPaqp1bcTWurxd6VdBTYoMqImmVUPgyD\nUTmz8ySV/n2AnUAjRyCNGEy218h239XHmXxBTvfrbweqjmRoRlySVkx3x0haBWwBXmhiezUS0zLg\nB1T6YL8/Y1mRbVX1szJLvDcDR5K2GQR2qDJqaT2wAfjHHLE0HJOkzcA3gRsj4kyqvOrfsoCYGo1r\ndWr2RuDFZHoIuCGJbwVwA9kj5qbGlcS2kcrJ3J+myprZXvUMAp9PRid9FPhl8qOneW3VjLPsnfAC\nfpdKRj0HvAYMJeVrgB+l6n0aeIlK5r8zVX4llf+8o8D3geUFxXUZ8ARwMnlfmZSXgPtS9dYB48B7\nZqx/BBim8iX318Cvtyou4N8n+/5Z8n5bM9urwZg+B5wHnk29NjWjrap9Vqh0Td2YTF+a/NtHk7a4\nMrXuncl6I8CnCvyc14vp8eTzP902g/X+li2K627gRLL/J4EPpdb9g6QNR4EvtDKuZP4/A/tmrNe0\n9qLy4+908jkeo3Iu6IvAF5PlAu5NYh4mNcqyWW3lK5/NzCxjsXclmZnZDE4MZmaW4cRgZmYZTgxm\nZpbhxGBmZhlODGZmluHEYGZmGU4MZmaW8f8B0tI9wTs74GUAAAAASUVORK5CYII=\n", "text/plain": [ "<matplotlib.figure.Figure at 0x1c11f931d0>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.scatter(training_features[:,0], training_features[:,1])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load model and perform cross-validation" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "Initialize the K-Nearest Neighbor classifier:\n", "* Here K=3 (three neighbors). K is a hyperparameter: it is not learned from the data, so it can be tuned to select the best value (we will get to that below)." ] }, { "cell_type": "code", "execution_count": 20, "metadata": { "collapsed": true }, "outputs": [], "source": [ "model = sklearn.neighbors.KNeighborsClassifier(n_neighbors=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Perform 5-fold cross validation (cv=5):" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "acc = sklearn.model_selection.cross_val_score(model, training_features, training_labels, cv=5)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 1. 1. 1. 1. 1.]\n", "1.0\n" ] } ], "source": [ "print(acc)\n", "print(acc.mean())" ] }, { "cell_type": "markdown", "metadata": { "collapsed": true }, "source": [ "### Tuning parameter K" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since the value of K was chosen arbitrarily, we do not know whether it is the best choice (although in this example the score is already perfect).\n", "Repeating the cross validation with several other values of K helps to choose the best one." 
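, "\n", "\n",
"To make explicit what such a search computes, the following is a minimal plain-Python sketch rather than scikit-learn's implementation. It copies the feature values printed above; the helper names `knn_predict` and `cv_accuracy`, the interleaved fold assignment, and the tie-breaking rule are assumptions made for illustration, and `cross_val_score` may split the data and break ties differently.\n",
"\n",
"```python\n",
"# Scaled (zero crossing rate, spectral centroid) pairs copied from the output above.\n",
"features = [\n",
"    (-0.56578947, 0.2490469), (-0.92105263, -0.63084112), (-1.0, -1.0),\n",
"    (-1.0, 0.32776395), (-0.77631579, -0.55766302), (-0.96052632, -0.69890631),\n",
"    (-0.68421053, -0.6353643), (-0.73684211, -0.10209073), (-0.82894737, -0.38692759),\n",
"    (-0.69736842, -0.5365142), (0.53947368, 0.89837486), (0.63157895, 0.55248334),\n",
"    (1.0, 0.99378504), (0.93421053, 1.0), (0.85526316, 0.87391641),\n",
"    (0.89473684, 0.96797404), (0.23684211, 0.86119147), (0.97368421, 0.79526304),\n",
"    (0.40789474, 0.38793404), (0.88157895, 0.88600432),\n",
"]\n",
"labels = [0] * 10 + [1] * 10\n",
"\n",
"def sq_dist(a, b):\n",
"    # Squared Euclidean distance between two 2-D points.\n",
"    return (a[0] - b[0]) ** 2 + (a[1] - b[1]) ** 2\n",
"\n",
"def knn_predict(train_idx, query, k):\n",
"    # Majority vote among the k training points closest to the query.\n",
"    nearest = sorted(train_idx, key=lambda i: sq_dist(features[i], query))[:k]\n",
"    votes = sum(labels[i] for i in nearest)\n",
"    return 1 if 2 * votes > k else 0  # assumption: ties go to class 0\n",
"\n",
"def cv_accuracy(k, n_folds=5):\n",
"    # Assumption: fold f holds out samples f, f+5, f+10, f+15 (two per class).\n",
"    fold_scores = []\n",
"    for fold in range(n_folds):\n",
"        test_idx = list(range(fold, len(features), n_folds))\n",
"        train_idx = [i for i in range(len(features)) if i not in test_idx]\n",
"        hits = sum(knn_predict(train_idx, features[i], k) == labels[i] for i in test_idx)\n",
"        fold_scores.append(hits / float(len(test_idx)))\n",
"    return sum(fold_scores) / n_folds\n",
"\n",
"# Mean held-out accuracy for each candidate K, then the smallest K with the top score.\n",
"scores = {k: cv_accuracy(k) for k in [1, 2, 3, 4, 5]}\n",
"best_k = max(sorted(scores), key=scores.get)\n",
"```\n",
"\n",
"When several values of K reach the same mean accuracy, `best_k` is simply the smallest of them, since a tie gives no evidence for preferring a larger neighborhood."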
] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[K=1] Accuracy=1.000\n", "[K=2] Accuracy=1.000\n", "[K=3] Accuracy=1.000\n", "[K=4] Accuracy=1.000\n", "[K=5] Accuracy=1.000\n" ] } ], "source": [ "K_choices = [1,2,3,4,5]\n", "for k in K_choices:\n", "    model = sklearn.neighbors.KNeighborsClassifier(n_neighbors=k)\n", "    mean_score = sklearn.model_selection.cross_val_score(model, training_features, training_labels, cv=5).mean()\n", "    print(\"[K=%d] Accuracy=%.3f\" % (k, mean_score))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Bonus: Try out other models and cross-validate each one to find the best model!" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "collapsed": true }, "outputs": [], "source": [ "# Logistic regression for binary classification\n", "model = sklearn.linear_model.LogisticRegression()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "[← Back to Index](index.html)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 2", "language": "python", "name": "python2" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.14" } }, "nbformat": 4, "nbformat_minor": 1 }