{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Geometric Decoder Optimization\n", "\n", "This notebook obtains an effectively \"infinite\" number of evaluation points by computing the continuous (integral) versions of $\\Gamma = A^T A$ and $\\Upsilon = A^T f(x)$, which Nengo normally estimates from a finite set of sampled evaluation points. We do so for the scalar case, with $f(x)$ restricted to polynomials. The higher-dimensional case requires more computational legwork (search for \"integrating monomials over convex polytopes\")." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pylab inline\n", "import numpy as np\n", "import scipy.linalg\n", "import pylab\n", "try:\n", "    import seaborn as sns  # optional; prettier graphs\n", "    edgecolors = sns.color_palette()[2]\n", "except ImportError:\n", "    edgecolors = 'r'\n", "\n", "import nengo\n", "from nengo.neurons import RectifiedLinear, Sigmoid, LIFRate\n", "\n", "from nengolib.compat import get_activities" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Decoded Function\n", "\n", "(Limited to polynomials.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "identity = np.poly1d([1, 0])  # f(x) = 1x + 0\n", "square = np.poly1d([1, 0, 0])  # f(x) = 1x^2 + 0x + 0\n", "quartic = np.poly1d([1, -1, -1, 0, 0])  # f(x) = x^4 - x^3 - x^2\n", "function = identity" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Neuron Model\n", "\n", "(Limited to these three for now.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# neuron_type = RectifiedLinear()\n", "# neuron_type = Sigmoid()\n", "neuron_type = LIFRate()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Baseline Decoders\n", "\n", "Let Nengo determine the gains / biases, given:\n", " - neuron model\n", " - number of neurons\n", " - seed\n", "\n", "And let Nengo solve for decoders (via Monte Carlo sampling), given:\n", " - function\n", " - number of evaluation 
points\n", " - solver and regularization" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "n_neurons = 10\n", "n_eval_points = 50\n", "solver = nengo.solvers.LstsqL2(reg=0.01)\n", "tuning_seed = None" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with nengo.Network() as model:\n", "    x = nengo.Ensemble(\n", "        n_neurons, 1, neuron_type=neuron_type,\n", "        n_eval_points=n_eval_points, seed=tuning_seed)\n", "\n", "    conn = nengo.Connection(\n", "        x, nengo.Node(size_in=1),\n", "        function=lambda x: np.polyval(function, x), solver=solver)\n", "\n", "# Building the model is all we need; the simulator is never run.\n", "with nengo.Simulator(model) as sim:\n", "    pass" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "eval_points = sim.data[x].eval_points\n", "e = sim.data[x].encoders.squeeze()\n", "gain = sim.data[x].gain\n", "bias = sim.data[x].bias\n", "intercepts = sim.data[x].intercepts\n", "\n", "if neuron_type == nengo.neurons.Sigmoid():\n", "    # Hack to fix intercepts:\n", "    # https://github.com/nengo/nengo/issues/1211\n", "    intercepts = -np.ones_like(intercepts)\n", "\n", "d_alg = sim.data[conn].weights.T" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Refined Decoders" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "boundaries = e * intercepts\n", "on, off = [], []\n", "for i in range(n_neurons):\n", "    on.append(-1. if e[i] < 0 else boundaries[i])\n", "    off.append(1. 
if e[i] > 0 else boundaries[i])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Some useful helper functions:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def dint(p, x1, x2):\n", "    \"\"\"Computes `int_{x1}^{x2} p(x) dx` where `p` is a polynomial.\"\"\"\n", "    # np.diff returns a length-1 array; index it to get a scalar.\n", "    return np.diff(np.polyval(np.polyint(p), [x1, x2]))[0]\n", "\n", "\n", "def quadratic_taylor(g, dg, ddg):\n", "    \"\"\"Returns a function that approximates g(ai*ei*x + bi) around x=y.\"\"\"\n", "    def curve(i, y):\n", "        j = gain[i] * e[i] * y + bias[i]\n", "        f = g(j)\n", "        df = gain[i] * e[i] * dg(j)\n", "        ddf = (gain[i] * e[i])**2 * ddg(j)\n", "        return np.poly1d([\n", "            ddf / 2, df - y * ddf, f - y * df + y**2 * ddf / 2])\n", "    return curve\n", "\n", "\n", "def segments(x1, x2, max_segments, min_width=0.05):\n", "    \"\"\"Partitions [x1, x2] into segments (l, m, u) where m = (l + u) / 2.\"\"\"\n", "    if x1 >= x2:\n", "        return []\n", "    n_segments = max(min(max_segments, int((x2 - x1) / min_width)), 1)\n", "    r = np.zeros((n_segments, 3))\n", "    r[:, 0] = np.arange(n_segments) * (x2 - x1) / n_segments + x1\n", "    r[:, 2] = np.arange(1, n_segments + 1) * (x2 - x1) / n_segments + x1\n", "    r[:, 1] = (r[:, 0] + r[:, 2]) / 2\n", "    return r" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Approximate each neuron's tuning curve with piecewise quadratic Taylor polynomials, one per segment. (`RectifiedLinear` is already linear where it is active, so it needs only a single segment.)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [ "if neuron_type == nengo.neurons.RectifiedLinear():\n", " n_segments = 1\n", " def curve(i, _):\n", "     return np.poly1d([gain[i] * e[i], bias[i]])\n", "\n", "elif neuron_type == nengo.neurons.Sigmoid():\n", " n_segments = min(n_eval_points, 50)\n", " ref = x.neuron_type.tau_ref\n", " g = lambda j: 1. 
/ ref / (1 + np.exp(-j))\n", " dg = lambda j: np.exp(-j) / ref / (1 + np.exp(-j))**2\n", " ddg = lambda j: 2*np.exp(-2*j) / ref / (1 + np.exp(-j))**3 - dg(j)\n", " curve = quadratic_taylor(g, dg, ddg)\n", "\n", "elif neuron_type == nengo.neurons.LIFRate():\n", " n_segments = min(n_eval_points, 50)\n", " ref = x.neuron_type.tau_ref\n", " rc = x.neuron_type.tau_rc\n", " g = lambda j: 1. / (ref + rc * np.log1p(1 / (j - 1)))\n", " dg = lambda j: g(j)**2 * rc / j / (j - 1)\n", " ddg = lambda j: (g(j)**3 * rc * (2*rc + ref - 2*j*ref +\n", "                                  (rc - 2*j*rc)*np.log1p(1 / (j - 1))) /\n", "                  j**2 / (j - 1)**2)\n", " curve = quadratic_taylor(g, dg, ddg)\n", "\n", "else:\n", " raise ValueError(\"Unsupported neuron type\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Compute a more accurate $\\Gamma$ (`G`) and $\\Upsilon$ (`U`) by integrating the required polynomial products over each segment. For `G`, a pair of neurons is integrated only over the interval where both are active. This nested loop is unoptimized and could be made more efficient." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "G = np.zeros((n_neurons, n_neurons))\n", "U = np.zeros(n_neurons)\n", "\n", "for i, (li, ui) in enumerate(zip(on, off)):\n", "    for x1, xm, x2 in segments(li, ui, n_segments):\n", "        U[i] += dint(curve(i, xm) * function, x1, x2)\n", "    for j, (lj, uj) in enumerate(zip(on, off)):\n", "        for x1, xm, x2 in segments(max(li, lj), min(ui, uj), n_segments):\n", "            G[i, j] += dint(curve(i, xm) * curve(j, xm), x1, x2)\n", "\n", "assert np.allclose(G.T, G)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Solve $\\Gamma d = \\Upsilon$ for the decoders, as we normally do, adding the same kind of L2 regularization that Nengo applies:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# d_geo = np.linalg.inv(G).dot(U)\n", "\n", "# More complicated decoder solver adapted from:\n", "# https://github.com/nengo/nengo/blob/84db35b5dd673ec715c4b11a0a9afae074f1895f/nengo/utils/least_squares_solvers.py#L32\n", "# in order to make comparisons fair with LstsqL2(reg=reg) where reg > 0.\n", "# Note this is 
not 'perfect' though, because the test set might yield a different effective regularization\n", "# than the entire integral. There is probably no way to have a perfect comparison.\n", "\n", "# Normalize G and U to be on par with the matrices used by Nengo.\n", "# 2 = 1 - (-1) is the length (1-D volume) of the domain [-1, 1].\n", "G *= len(eval_points) / 2\n", "U *= len(eval_points) / 2\n", "\n", "A_nengo = get_activities(sim.model, x, eval_points)\n", "max_rate = np.max(A_nengo)\n", "sigma = solver.reg * max_rate\n", "m = len(eval_points)\n", "np.fill_diagonal(G, G.diagonal() + m * sigma**2)\n", "\n", "factor = scipy.linalg.cho_factor(G, overwrite_a=True)\n", "d_geo = scipy.linalg.cho_solve(factor, U)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot the piecewise approximation of one neuron's tuning curve against the actual curve, for debugging purposes:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "i = 1\n", "\n", "x_test = np.linspace(-1, 1, 100000)\n", "x_test = x_test[x_test / e[i] > intercepts[i]]\n", "acts = get_activities(sim.model, x, x_test[:, None])\n", "\n", "pylab.figure()\n", "pylab.plot(x_test, acts[:, i], linestyle='--', label=\"Actual\")\n", "for j, (x1, xm, x2) in enumerate(segments(on[i], off[i], n_segments)):\n", "    sl = (x_test > x1) & (x_test < x2)\n", "    pylab.plot(x_test[sl], np.polyval(curve(i, xm), x_test[sl]),\n", "               lw=2, alpha=0.8, label=\"%d\" % j)\n", "pylab.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Results" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vertices = np.concatenate(([-1], np.sort(boundaries), [1]))\n", "\n", "for x_test, title in ((np.sort(eval_points.squeeze()), \"Training Data\"),\n", " (np.linspace(-1, 1, 100000), \"Test Data\")):\n", " y = conn.function(x_test)\n", " activities = get_activities(sim.model, x, x_test[:, None])\n", "\n", " pylab.figure()\n", " pylab.title(title)\n", " for d, label in ((d_alg, \"Algebraic\"),\n", " (d_geo, 
\"Geometric\")):\n", "     y_hat = np.dot(activities, d).squeeze()\n", "     percent_error = 100 * nengo.utils.numpy.rmse(y_hat, y) / nengo.utils.numpy.rms(y)\n", "     pylab.plot(x_test, y_hat - y, label=\"%s (%.2f%%)\" % (label, percent_error))\n", " pylab.plot(x_test, np.zeros_like(x_test), lw=2, alpha=0.8, label=\"Ideal\")\n", " pylab.scatter(vertices, np.zeros_like(vertices), s=50, lw=2, facecolors='none',\n", "               edgecolors=edgecolors, alpha=0.8, label=\"Vertices\")\n", " pylab.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)\n", " pylab.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "collapsed": true }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }