In [3]:
import numpy, scipy, matplotlib.pyplot as plt, sklearn, stanford_mir
%matplotlib inline

Cross Validation

K-fold cross validation is a method for evaluating the correctness of a classifier.

For example, with 10-fold cross validation:

  1. Divide the data set into 10 random partitions.
  2. Choose one of the partitions as the test set. Train on the other nine partitions.
  3. Repeat for the partitions.

Load some features from ten kick drums and ten snare drums:

In [4]:
training_features, training_labels, scaler = stanford_mir.get_features()
---------------------------------------------------------------------------
IOError                                   Traceback (most recent call last)
<ipython-input-4-a7910ded3ac4> in <module>()
----> 1 training_features, training_labels, scaler = stanford_mir.get_features()

/Users/steve/stanford-mir/stanford_mir/core.pyc in get_features(collection, features, scaler)
     20                  scaler=None):
     21     if collection == 'drum_samples_train':
---> 22         kick_filepaths, snare_filepaths = download_samples('drum_samples_train')
     23         kick_signals = [
     24             librosa.load(p)[0] for p in kick_filepaths

/Users/steve/stanford-mir/stanford_mir/core.pyc in download_samples(collection)
     86                 filename = '%s_%02d.wav' % (drum_type, i)
     87                 urllib.urlretrieve('http://audio.musicinformationretrieval.com/drum_samples/%s' % filename,
---> 88                                    filename=os.path.join(collection, filename))
     89         kick_filepaths = [os.path.join(collection, 'kick_%02d.wav' % i) for i in range(1, 11)]
     90         snare_filepaths = [os.path.join(collection, 'snare_%02d.wav' % i) for i in range(1, 11)]

/usr/local/Cellar/python/2.7.6_1/Frameworks/Python.framework/Versions/2.7/lib/python2.7/urllib.pyc in urlretrieve(url, filename, reporthook, data)
     92     if not _urlopener:
     93         _urlopener = FancyURLopener()
---> 94     return _urlopener.retrieve(url, filename, reporthook, data)
     95 def urlcleanup():
     96     if _urlopener:

/usr/local/Cellar/python/2.7.6_1/Frameworks/Python.framework/Versions/2.7/lib/python2.7/urllib.pyc in retrieve(self, url, filename, reporthook, data)
    238             except IOError:
    239                 pass
--> 240         fp = self.open(url, data)
    241         try:
    242             headers = fp.info()

/usr/local/Cellar/python/2.7.6_1/Frameworks/Python.framework/Versions/2.7/lib/python2.7/urllib.pyc in open(self, fullurl, data)
    206         try:
    207             if data is None:
--> 208                 return getattr(self, name)(url)
    209             else:
    210                 return getattr(self, name)(url, data)

/usr/local/Cellar/python/2.7.6_1/Frameworks/Python.framework/Versions/2.7/lib/python2.7/urllib.pyc in open_http(self, url, data)
    343         if realhost: h.putheader('Host', realhost)
    344         for args in self.addheaders: h.putheader(*args)
--> 345         h.endheaders(data)
    346         errcode, errmsg, headers = h.getreply()
    347         fp = h.getfile()

/usr/local/Cellar/python/2.7.6_1/Frameworks/Python.framework/Versions/2.7/lib/python2.7/httplib.pyc in endheaders(self, message_body)
    967         else:
    968             raise CannotSendHeader()
--> 969         self._send_output(message_body)
    970 
    971     def request(self, method, url, body=None, headers={}):

/usr/local/Cellar/python/2.7.6_1/Frameworks/Python.framework/Versions/2.7/lib/python2.7/httplib.pyc in _send_output(self, message_body)
    827             msg += message_body
    828             message_body = None
--> 829         self.send(msg)
    830         if message_body is not None:
    831             #message_body was not a string (i.e. it is a file) and

/usr/local/Cellar/python/2.7.6_1/Frameworks/Python.framework/Versions/2.7/lib/python2.7/httplib.pyc in send(self, data)
    789         if self.sock is None:
    790             if self.auto_open:
--> 791                 self.connect()
    792             else:
    793                 raise NotConnected()

/usr/local/Cellar/python/2.7.6_1/Frameworks/Python.framework/Versions/2.7/lib/python2.7/httplib.pyc in connect(self)
    770         """Connect to the host and port specified in __init__."""
    771         self.sock = socket.create_connection((self.host,self.port),
--> 772                                              self.timeout, self.source_address)
    773 
    774         if self._tunnel_host:

/usr/local/Cellar/python/2.7.6_1/Frameworks/Python.framework/Versions/2.7/lib/python2.7/socket.pyc in create_connection(address, timeout, source_address)
    551     host, port = address
    552     err = None
--> 553     for res in getaddrinfo(host, port, 0, SOCK_STREAM):
    554         af, socktype, proto, canonname, sa = res
    555         sock = None

IOError: [Errno socket error] [Errno 8] nodename nor servname provided, or not known
In [13]:
print training_labels
[ 0.  0.  0.  0.  0.  0.  0.  0.  0.  0.  1.  1.  1.  1.  1.  1.  1.  1.
  1.  1.]

Plot their features:

In [2]:
plt.scatter(training_features[:,0], training_features[:,1])
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-2-f00bdfd12558> in <module>()
----> 1 plt.scatter(training_features[:,0], training_features[:,1])

NameError: name 'training_features' is not defined

Initialize the classifier:

In [28]:
model = sklearn.neighbors.KNeighborsClassifier(n_neighbors=3)
model = sklearn.linear_model.LogisticRegression()

Perform 5-fold cross validation:

In [29]:
acc  = sklearn.cross_validation.cross_val_score(model, random_features, training_labels, cv=5)
In [30]:
print acc.mean()
0.5
In [ ]:
k