Commit a7f5c412 authored by Steve Tjoa's avatar Steve Tjoa

stanford_mir.init()

parent 15c3d2b2
import errno import errno
import librosa import librosa
import matplotlib, matplotlib.pyplot as plt
import numpy import numpy
import os import os
import os.path import os.path
import sklearn import sklearn
import urllib import urllib.request
def init():
plt.style.use('seaborn-muted')
#plt.rcParams['figure.figsize'] = (14, 5)
plt.rcParams['axes.grid'] = True
plt.rcParams['axes.spines.left'] = False
plt.rcParams['axes.spines.right'] = False
plt.rcParams['axes.spines.bottom'] = False
plt.rcParams['axes.spines.top'] = False
plt.rcParams['axes.xmargin'] = 0
plt.rcParams['axes.ymargin'] = 0
plt.rcParams['image.cmap'] = 'gray'
plt.rcParams['image.interpolation'] = None
def extract_features(signal, features): def extract_features(signal, features):
fvs = list() fvs = list()
...@@ -15,6 +30,7 @@ def extract_features(signal, features): ...@@ -15,6 +30,7 @@ def extract_features(signal, features):
fvs.append( librosa.feature.spectral_centroid(signal)[0, 0] ) fvs.append( librosa.feature.spectral_centroid(signal)[0, 0] )
return fvs return fvs
def get_features(collection='drum_samples_train', def get_features(collection='drum_samples_train',
features=('zero_crossing_rate', 'spectral_centroid'), features=('zero_crossing_rate', 'spectral_centroid'),
scaler=None, scaler=None,
...@@ -88,7 +104,7 @@ def download_samples(collection='drum_samples_train', download=True): ...@@ -88,7 +104,7 @@ def download_samples(collection='drum_samples_train', download=True):
for drum_type in ['kick', 'snare']: for drum_type in ['kick', 'snare']:
for i in range(1, 11): for i in range(1, 11):
filename = '%s_%02d.wav' % (drum_type, i) filename = '%s_%02d.wav' % (drum_type, i)
urllib.urlretrieve('http://audio.musicinformationretrieval.com/drum_samples/%s' % filename, urllib.request.urlretrieve('http://audio.musicinformationretrieval.com/drum_samples/%s' % filename,
filename=os.path.join(collection, filename)) filename=os.path.join(collection, filename))
kick_filepaths = [os.path.join(collection, 'kick_%02d.wav' % i) for i in range(1, 11)] kick_filepaths = [os.path.join(collection, 'kick_%02d.wav' % i) for i in range(1, 11)]
snare_filepaths = [os.path.join(collection, 'snare_%02d.wav' % i) for i in range(1, 11)] snare_filepaths = [os.path.join(collection, 'snare_%02d.wav' % i) for i in range(1, 11)]
...@@ -99,19 +115,19 @@ def download_samples(collection='drum_samples_train', download=True): ...@@ -99,19 +115,19 @@ def download_samples(collection='drum_samples_train', download=True):
for drum_type in ['kick', 'snare']: for drum_type in ['kick', 'snare']:
for i in range(30): for i in range(30):
filename = '%s_%02d.wav' % (drum_type, i) filename = '%s_%02d.wav' % (drum_type, i)
urllib.urlretrieve('http://audio.musicinformationretrieval.com/drum_samples/test/%s' % filename, urllib.request.urlretrieve('http://audio.musicinformationretrieval.com/drum_samples/test/%s' % filename,
filename=os.path.join(collection, filename)) filename=os.path.join(collection, filename))
kick_filepaths = [os.path.join(collection, 'kick_%02d.wav' % i) for i in range(30)] kick_filepaths = [os.path.join(collection, 'kick_%02d.wav' % i) for i in range(30)]
snare_filepaths = [os.path.join(collection, 'snare_%02d.wav' % i) for i in range(30)] snare_filepaths = [os.path.join(collection, 'snare_%02d.wav' % i) for i in range(30)]
return kick_filepaths, snare_filepaths return kick_filepaths, snare_filepaths
elif collection == 'violin_samples_train': elif collection == 'violin_samples_train':
urllib.urlretrieve('http://audio.musicinformationretrieval.com/violin_samples_train/list.txt', urllib.request.urlretrieve('http://audio.musicinformationretrieval.com/violin_samples_train/list.txt',
filename=os.path.join(collection, 'list.txt')) filename=os.path.join(collection, 'list.txt'))
for line in open(os.path.join(collection, 'list.txt'), 'r'): for line in open(os.path.join(collection, 'list.txt'), 'r'):
filename = line.strip() filename = line.strip()
print filename print(filename)
if filename.endswith('.wav'): if filename.endswith('.wav'):
urllib.urlretrieve('http://audio.musicinformationretrieval.com/' + filename, urllib.request.urlretrieve('http://audio.musicinformationretrieval.com/' + filename,
filename=filename) filename=filename)
return [os.path.join(collection, f) for f in os.listdir(collection)] return [os.path.join(collection, f) for f in os.listdir(collection)]
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment