Commit 58de3a15 authored by Steve Tjoa's avatar Steve Tjoa

Merge branch 'master' of github.com:stevetjoa/stanford-mir

parents 82e1aca2 6fe2a0a5
{
"metadata": {
"name": "",
"signature": "sha256:75f989dd2ab73ca4005602cfbe578941785ac582da4dd30251ada2d364be673a"
"signature": "sha256:0949f131655e064cf3e40a0f5c298b8bd0662c1299e6aa2be192c8ce2d4c160a"
},
"nbformat": 3,
"nbformat_minor": 0,
......@@ -106,6 +106,11 @@
"cell_type": "code",
"collapsed": false,
"input": [
"import numpy as np\n",
"from sklearn import cross_validation\n",
"from sklearn.neighbors import KNeighborsClassifier\n",
"from sklearn import preprocessing\n",
"\n",
"def crossValidateKNN(features, labels):\n",
" \"\"\"\n",
" This code is provided as a template for your cross-validation\n",
......@@ -130,15 +135,14 @@
" print(\"TEST: %s\" % test_index)\n",
" \n",
" # SCALE\n",
" trainingFeatures, mf, sf = scale(features.take(train_index, 0))\n",
" scaler = preprocessing.MinMaxScaler(feature_range = (-1, 1))\n",
" trainingFeatures = scaler.fit_transform(features.take(train_index, 0))\n",
" # BUILD NEW MODEL - ADD YOUR MODEL BUILDING CODE HERE...\n",
" # model = knn(numFeatures, 2, 3, trainingFeatures, labels[train_index, :]) \n",
" model = KNeighborsClassifier(n_neighbors = 3)\n",
" model.fit(trainingFeatures, labels.take(train_index, 0))\n",
" # RESCALE TEST DATA TO TRAINING SCALE SPACE\n",
" testingFeatures = rescale(features.take(test_index, 0), mf, sf)\n",
" testingFeatures = scaler.transform(features.take(test_index, 0))\n",
" # EVALUATE WITH TEST DATA - ADD YOUR MODEL EVALUATION CODE HERE\n",
" # voting, model_output = knnfwd(model, testingFeatures)\n",
" model_output = model.predict(testingFeatures)\n",
" print(\"KNN prediction %s\" % model_output) # Debugging.\n",
" # CONVERT labels(test,:) LABELS TO SAME FORMAT TO COMPUTE ERROR \n",
......@@ -147,7 +151,7 @@
" matches = model_output != labels_test\n",
" errors[foldIndex] = matches.mean()\n",
" print('cross validation error: %f' % errors.mean())\n",
" print('cross validation accuracy: %f' % (1.0 - errors.mean()))"
" print('cross validation accuracy: %f' % (1.0 - errors.mean()))\n"
],
"language": "python",
"metadata": {},
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment