Commit 8f27d159 authored by Steve Tjoa's avatar Steve Tjoa

adding lsh music fingerprinting

parent 475783e8
{
"metadata": {
"name": "",
"signature": "sha256:1e57ed1516a61260e593f53b74843c63ebc0cf1bf4a9edcdd0616da4d46b32be"
},
"nbformat": 3,
"nbformat_minor": 0,
"worksheets": [
{
"cells": [
{
"cell_type": "heading",
"level": 2,
"metadata": {},
"source": [
"Import libraries:"
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"import os\n",
"import os.path\n",
"\n",
"import essentia.standard as ess\n",
"import numpy as np"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 1
},
{
"cell_type": "heading",
"level": 2,
"metadata": {},
"source": [
"Select training data:"
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"training_dir = '../train/'\n",
"# os.listdir returns entries in arbitrary order and also lists sub-directories;\n",
"# keep regular files only and sort them so the index is built reproducibly.\n",
"training_files = sorted(\n",
"    path\n",
"    for path in (os.path.join(training_dir, f) for f in os.listdir(training_dir))\n",
"    if os.path.isfile(path)\n",
")"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 79
},
{
"cell_type": "heading",
"level": 2,
"metadata": {},
"source": [
"Define a hash function:"
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"def hash_func(vec, projections):\n",
"    \"\"\"Hash a vector to an integer from the signs of its projections.\n",
"\n",
"    Each row of `projections` is one projection direction; bit i of the\n",
"    result is set when the dot product of row i with `vec` is positive.\n",
"\n",
"    vec: 1-D vector whose length equals the number of columns of `projections`.\n",
"    projections: 2-D array with one row per hash bit.\n",
"    Returns the integer that bool2int builds from the per-row sign flags.\n",
"    \"\"\"\n",
"    # np.dot is called explicitly instead of relying on a pylab-injected global `dot`,\n",
"    # which is undefined on a fresh kernel.\n",
"    bools = np.dot(projections, vec) > 0\n",
"    return bool2int(bools)"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 80
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"def bool2int(x):\n",
" y = 0\n",
" for i,j in enumerate(x):\n",
" if j: y += 1<<i\n",
" return y"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 81
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# Sanity check: 5 projection rows give a 5-bit hash, i.e. a value in [0, 32).\n",
"# np.random.randn is used explicitly; a bare `randn` only exists under pylab.\n",
"projections = np.random.randn(5, 512)\n",
"x = np.random.randn(512)\n",
"hash_func(x, projections)"
],
"language": "python",
"metadata": {},
"outputs": [
{
"metadata": {},
"output_type": "pyout",
"prompt_number": 82,
"text": [
"26"
]
}
],
"prompt_number": 82
},
{
"cell_type": "heading",
"level": 2,
"metadata": {},
"source": [
"Define the three classes that make up the search index (Table, LSH, and MusicSearch):"
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"class Table:\n",
" \n",
" def __init__(self, hash_size, dim):\n",
" self.table = dict()\n",
" self.hash_size = hash_size\n",
" self.dim = dim # TODO is this necessary?\n",
" self.projections = randn(self.hash_size, self.dim)\n",
"\n",
" def add(self, vec, label):\n",
" entry = {'vector': None, 'label': label}\n",
" h = hash_func(vec, self.projections)\n",
" if self.table.has_key(h):\n",
" self.table[h].append(entry)\n",
" else:\n",
" self.table[h] = [entry]\n",
"\n",
" def query(self, vec):\n",
" h = hash_func(vec, self.projections)\n",
" if self.table.has_key(h):\n",
" results = self.table[h]\n",
" else:\n",
" results = list()\n",
" return results"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 83
},
{
"cell_type": "code",
"collapsed": false,
"input": [],
"language": "python",
"metadata": {},
"outputs": []
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"class LSH:\n",
" \n",
" def __init__(self, dim):\n",
" self.num_tables = 7\n",
" self.hash_size = 10\n",
" self.dim = dim\n",
" self.tables = list()\n",
" for i in range(self.num_tables):\n",
" self.tables.append(Table(self.hash_size, self.dim))\n",
" \n",
" def add(self, vec, label):\n",
" for table in self.tables:\n",
" table.add(vec, label)\n",
" \n",
" def query(self, vec):\n",
" results = list()\n",
" for table in self.tables:\n",
" results.extend(table.query(vec))\n",
" return results\n",
"\n",
" def describe(self):\n",
" for table in self.tables:\n",
" print table.table"
],
"language": "python",
"metadata": {},
"outputs": [],
"prompt_number": 84
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"class MusicSearch:\n",
"    \"\"\"Audio search that indexes per-frame spectral features of training files in an LSH.\"\"\"\n",
"\n",
"    def __init__(self, training_files):\n",
"        \"\"\"Set the framing parameters and create an empty index for `training_files`.\"\"\"\n",
"        self.frame_size = 4096\n",
"        self.hop_size = 4000\n",
"        self.fv_size = 1000\n",
"        self.lsh = LSH(self.fv_size)\n",
"        self.training_files = training_files\n",
"        # Number of feature vectors indexed per training file; query() divides by it.\n",
"        self.num_features_in_file = dict()\n",
"        for f in self.training_files:\n",
"            self.num_features_in_file[f] = 0\n",
"\n",
"    def get_features(self, frame):\n",
"        \"\"\"Return the first `fv_size` values of the spectrum of the Hamming-windowed frame.\"\"\"\n",
"        hamming_window = ess.Windowing(type='hamming')\n",
"        spectrum = ess.Spectrum()\n",
"        return spectrum(hamming_window(frame))[:self.fv_size]\n",
"\n",
"    def _file_features(self, filepath):\n",
"        \"\"\"Yield one feature vector per frame of the audio file at `filepath`.\"\"\"\n",
"        x = ess.MonoLoader(filename=filepath)()\n",
"        for frame in ess.FrameGenerator(x, frameSize=self.frame_size, hopSize=self.hop_size):\n",
"            yield self.get_features(frame)\n",
"\n",
"    def train(self):\n",
"        \"\"\"Add every frame of every training file to the index, labelled by file path.\"\"\"\n",
"        for filepath in self.training_files:\n",
"            for vec in self._file_features(filepath):\n",
"                self.lsh.add(vec, filepath)\n",
"                self.num_features_in_file[filepath] += 1\n",
"\n",
"    def query(self, filepath):\n",
"        \"\"\"Score the training files against the audio file at `filepath`.\n",
"\n",
"        Returns a dict mapping each training file that was hit at least once to\n",
"        its number of bucket hits divided by the number of features indexed for it.\n",
"        \"\"\"\n",
"        counts = dict()\n",
"        for vec in self._file_features(filepath):\n",
"            for r in self.lsh.query(vec):\n",
"                label = r['label']\n",
"                # dict.get replaces dict.has_key, which does not exist in Python 3.\n",
"                counts[label] = counts.get(label, 0) + 1\n",
"        for k in counts:\n",
"            counts[k] = float(counts[k]) / self.num_features_in_file[k]\n",
"        return counts\n"
],
"language": "python",
"metadata": {},
"outputs": []
},
{
"cell_type": "heading",
"level": 2,
"metadata": {},
"source": [
"Train:"
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"ms = MusicSearch(training_files)\n",
"ms.train()"
],
"language": "python",
"metadata": {},
"outputs": []
},
{
"cell_type": "heading",
"level": 2,
"metadata": {},
"source": [
"Test:"
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"test_file = '../test/steve_bach_p3.wav'\n",
"results = ms.query(test_file)"
],
"language": "python",
"metadata": {},
"outputs": []
},
{
"cell_type": "heading",
"level": 2,
"metadata": {},
"source": [
"Display the results:"
]
},
{
"cell_type": "code",
"collapsed": false,
"input": [
"# Highest score first. A formatted string prints identically on Python 2 and 3,\n",
"# unlike the Python 2-only `print r, results[r]` statement.\n",
"for r in sorted(results, key=results.get, reverse=True):\n",
"    print('%s %s' % (r, results[r]))"
],
"language": "python",
"metadata": {},
"outputs": []
},
{
"cell_type": "code",
"collapsed": false,
"input": [],
"language": "python",
"metadata": {},
"outputs": []
}
],
"metadata": {}
}
]
}
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment