Commit dc9b8060 authored by Steve Tjoa's avatar Steve Tjoa

download=True option

parent 80f5cb83
......@@ -64,7 +64,7 @@ def get_features(collection='drum_samples_train',
return test_features, labels, scaler
def download_samples(collection='drum_samples_train'):
def download_samples(collection='drum_samples_train', download=True):
"""Download ten kick drum samples and ten snare drum samples.
`collection`: output directory containing the twenty drum samples
......@@ -80,23 +80,36 @@ def download_samples(collection='drum_samples_train'):
except OSError as exception:
if exception.errno != errno.EEXIST:
raise
else:
print "Directory %s already exists." % collection
if collection == 'drum_samples_train':
for drum_type in ['kick', 'snare']:
for i in range(1, 11):
filename = '%s_%02d.wav' % (drum_type, i)
urllib.urlretrieve('http://audio.musicinformationretrieval.com/drum_samples/%s' % filename,
filename=os.path.join(collection, filename))
if download:
for drum_type in ['kick', 'snare']:
for i in range(1, 11):
filename = '%s_%02d.wav' % (drum_type, i)
urllib.urlretrieve('http://audio.musicinformationretrieval.com/drum_samples/%s' % filename,
filename=os.path.join(collection, filename))
kick_filepaths = [os.path.join(collection, 'kick_%02d.wav' % i) for i in range(1, 11)]
snare_filepaths = [os.path.join(collection, 'snare_%02d.wav' % i) for i in range(1, 11)]
return kick_filepaths, snare_filepaths
elif collection == 'drum_samples_test':
for drum_type in ['kick', 'snare']:
for i in range(30):
filename = '%s_%02d.wav' % (drum_type, i)
urllib.urlretrieve('http://audio.musicinformationretrieval.com/drum_samples/test/%s' % filename,
filename=os.path.join(collection, filename))
if download:
for drum_type in ['kick', 'snare']:
for i in range(30):
filename = '%s_%02d.wav' % (drum_type, i)
urllib.urlretrieve('http://audio.musicinformationretrieval.com/drum_samples/test/%s' % filename,
filename=os.path.join(collection, filename))
kick_filepaths = [os.path.join(collection, 'kick_%02d.wav' % i) for i in range(30)]
snare_filepaths = [os.path.join(collection, 'snare_%02d.wav' % i) for i in range(30)]
return kick_filepaths, snare_filepaths
elif collection == 'violin_samples_train':
urllib.urlretrieve('http://audio.musicinformationretrieval.com/violin_samples_train/list.txt',
filename=os.path.join(collection, 'list.txt'))
for line in open(os.path.join(collection, 'list.txt'), 'r'):
filename = line.strip()
print filename
if filename.endswith('.wav'):
urllib.urlretrieve('http://audio.musicinformationretrieval.com/' + filename,
filename=filename)
return [os.path.join(collection, f) for f in os.listdir(collection)]
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment