{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Computing Functions Across a Rolling Window of Time\n", "\n", "The following notebook demonstrates use of the `RollingWindow` network, which is a wrapper around `LinearNetwork` that sets `sys=PadeDelay(theta, order=dimensions)` and uses a few extra formulas / tricks. This network allows one to compute nonlinear functions over a finite rolling window of input history. It is most accurate for low-frequency inputs and for low-order nonlinearities. See [1] for details." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pylab inline\n", "import pylab\n", "try:\n", " import seaborn as sns # optional; prettier graphs\n", "except ImportError:\n", " pass\n", "\n", "import numpy as np\n", "\n", "import nengo\n", "import nengolib" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Setting up the network\n", "\n", "We first create a `RollingWindow` network with $\\theta=0.1\\,s$ corresponding to the size of the window in seconds, and pick some number of `LIFRate` neurons (`2000`). The order of the approximation (the dimension of the network's state ensemble) defaults to 6, since this was found to give the best fit to neural \"time cell\" data in rodents.\n", "\n", "We also need to create an input stimulus. Here, we use a band-limited white noise `process` since these methods are optimal for low-frequency inputs. Next, connect this to the `input` node within the rolling window network.\n", "\n", "Additionally, we provide this `process` to the network's constructor to optimize its evaluation points and encoders (during the build phase) for this particular process. Note that we do _not_ fix the seed of the process in order to prevent overfitting, but we do make the process long enough (`10` seconds) for it to generalize. This step is optional, but can dramatically improve the performance. 
If `process=None` the input should ideally be modelled, or the `eval_points`, (orthogonal) `encoders`, and `radii` should be manually specified." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "process = nengo.processes.WhiteSignal(10.0, high=15, y0=0)\n", "neuron_type = nengo.LIFRate() # try out LIF() or Direct()\n", "\n", "with nengolib.Network() as model:\n", " rw = nengolib.networks.RollingWindow(\n", " theta=0.1, n_neurons=2000, process=process, neuron_type=neuron_type)\n", "\n", " stim = nengo.Node(output=process)\n", "\n", " nengo.Connection(stim, rw.input, synapse=None)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 2. Decoding functions from the window\n", "\n", "Next we use the `add_output(...)` method to decode functions from the state ensemble. This method takes a `t` argument specifying the relative time-points of interest, and a `function` argument that specifies the ideal function to be computed along this window of points. The method returns a node which approximates this function from the window of input history.\n", "\n", "The `t` parameter can either be a single float, or an array of floats in the range $[0, 1]$. The size of `t` corresponds to the length of the window array ${\\bf w}$ passed to your function, and each element of the `t` array corresponds to the normalized delay in time for its respective point from the window. The `function` parameter must then accept a parameter `w` that is sized according to `t`, and should output the desired function from the given window `w`. 
Decoders will be optimized to approximate this function from the state of the rolling window network.\n", "\n", "For example:\n", "\n", "* `add_output(t=0, function=lambda w: w)` approximates a communication channel $f(x(t)) = x(t)$ (_Note_: this effectively undoes the filtering from the synapse!).\n", "* `add_output(t=1, function=lambda w: w**2)` approximates the function $f(x(t)) = x(t-\\theta)^2$.\n", "* `add_output(t=[.5, 1], function=lambda w: w[1] - w[0])` approximates the function $f(x(t)) = x(t-\\theta) - x(t-\\theta/2)$.\n", "\n", "By default, `t` will be `1000` points spaced evenly between `0` and `1`. For example:\n", "\n", "* `add_output(function=np.mean)` approximates the mean of this sampled window.\n", "* `add_output(function=np.median)` approximates a [median filter](https://en.wikipedia.org/wiki/Median_filter).\n", "* `add_output(function=np.max)` approximates the height of the largest peak.\n", "* `add_output(function=lambda w: np.argmax(w)/float(len(w)))` approximates how long ago the largest peak occurred (as a fraction of the window length).\n", "\n", "The function can also return multiple dimensions.\n", "\n", "Here we compute two functions from the same state: (1) a delay of $\\theta$ seconds, and (2) the first four _moments_ of the window." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with model:\n", "    delay = rw.output  # equivalent to: rw.add_output(t=1)\n", "\n", "    def compute_moments(w):\n", "        \"\"\"Returns the first four moments of the window w.\"\"\"\n", "        return np.mean(w), np.mean(w**2), np.mean(w**3), np.mean(w**4)\n", "\n", "    moments = rw.add_output(function=compute_moments)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 3. 
Set up probes" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tau_probe = 0.01 # to filter the spikes\n", "\n", "with model:\n", " p_stim_unfiltered = nengo.Probe(stim, synapse=None)\n", "\n", " p_stim = nengo.Probe(stim, synapse=tau_probe) # filter for consistency\n", "\n", " p_delay = nengo.Probe(delay, synapse=tau_probe)\n", "\n", " p_moments = nengo.Probe(moments, synapse=tau_probe)\n", "\n", " p_x = nengo.Probe(rw.state, synapse=tau_probe) # for later analysis" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 4. Simulate the network" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with nengo.Simulator(model, seed=0) as sim:\n", " sim.run(1.0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 5. Plot results" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Compute the ideal for comparison\n", "ideal = np.zeros_like(sim.data[p_moments])\n", "w = np.zeros(int(rw.theta / rw.dt))\n", "for i in range(len(ideal)):\n", " ideal[i] = compute_moments(w)\n", " w[0] = sim.data[p_stim_unfiltered][i]\n", " w = nengolib.signal.shift(w)\n", "ideal = nengolib.Lowpass(tau_probe).filt(ideal, dt=rw.dt, axis=0)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pylab.figure(figsize=(14, 4))\n", "pylab.title(\"Decoding a Delay\")\n", "pylab.plot(sim.trange(), sim.data[p_stim], label=\"Input\")\n", "pylab.plot(sim.trange(), sim.data[p_delay], label=\"Delay\")\n", "pylab.xlabel(\"Time (s)\")\n", "pylab.legend()\n", "pylab.show()\n", "\n", "fig, ax = pylab.subplots(p_moments.size_in, 1, figsize=(15, 8))\n", "for i in range(p_moments.size_in):\n", " error = nengolib.signal.nrmse(sim.data[p_moments][:, i], target=ideal[:, i])\n", " ax[i].set_title(r\"$\\mathbb{E} \\left[{\\bf w}^%d\\right]$\" % (i + 1))\n", " ax[i].plot(sim.trange(), sim.data[p_moments][:, i], 
label=\"Actual (NRMSE=%.2f)\" % error)\n", " ax[i].plot(sim.trange(), ideal[:, i], lw=3, linestyle='--', label=\"Expected\")\n", " ax[i].legend(loc='upper right', bbox_to_anchor=(1.20, 1), borderaxespad=0.)\n", "ax[-1].set_xlabel(\"Time (s)\")\n", "pylab.show()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pylab.figure(figsize=(14, 6))\n", "pylab.title(\"State Space\")\n", "pylab.plot(sim.trange(), sim.data[p_x])\n", "pylab.xlabel(\"Time (s)\")\n", "pylab.ylabel(r\"${\\bf x}$\")\n", "pylab.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Understanding the network\n", "\n", "This network essentially uses the `PadeDelay` system of order $d$ to compress the input into a $d$-dimensional state ${\\bf x}$. This state vector represents a rolling window of input history by a linear combination of $d$ basis functions:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "B_canonical = rw.canonical_basis()\n", "\n", "pylab.figure()\n", "pylab.title(\"Canonical Basis\")\n", "pylab.plot(nengolib.networks.t_default, B_canonical)\n", "pylab.xlabel(\"Relative time (fraction of theta)\")\n", "pylab.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "But since the state-space is transformed (by default it is a \"balanced realization\"), we have the following change of basis (by the invertible linear transformation `rw.realizer_result.T`):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "B = rw.basis()\n", "assert np.allclose(B_canonical.dot(rw.realizer_result.T), B)\n", "\n", "pylab.figure()\n", "pylab.title(\"Realized Basis\")\n", "pylab.plot(nengolib.networks.t_default, B)\n", "pylab.xlabel(\"Relative time (fraction of theta)\")\n", "pylab.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since the encoders of the network are axis-aligned (to improve accuracy of the linear system), the network is able to 
accurately decode functions of the form:\n", "\n", "$$f \\left( \\sum_{i=1}^d x_i {\\bf u}_i \\right) = \\sum_{i=1}^{d} f_i ( x_i )$$\n", "\n", "where ${\\bf u}_i$ is the $i^\\texttt{th}$ basis function, $x_i$ is the corresponding weight given by the state vector ${\\bf x}$, and each $f_i$ is some nonlinear function supported by the neural tuning curves (typically a low-order polynomial).\n", "\n", "We now write ${\\bf w} = B {\\bf x} = \\sum_{i=1}^d x_i {\\bf u}_i$ where $B = \\left[ {\\bf u}_1 \\ldots {\\bf u}_d \\right]$ is our basis matrix, and ${\\bf w}$ is the window of history. Then the Moore-Penrose pseudoinverse $B^+ = (B^T B)^{-1} B^T$ gives us the relationship ${\\bf x} = B^+ {\\bf w}$, where $B^+ = \\left[ {\\bf v}_1 \\ldots {\\bf v}_d \\right]^T$ and ${\\bf v}_i$ can be called the $i^\\texttt{th}$ \"inverse basis function\". Finally, we can rewrite the computed function $f$ with respect to the window ${\\bf w}$ as:\n", "\n", "$$f ( {\\bf w} ) = \\sum_{i=1}^{d} f_i ( {\\bf v}_i \\cdot {\\bf w} )$$\n", "\n", "In other words, the functions that we can compute most accurately will be some linear combination of low-order nonlinearities applied to each ${\\bf v}_i \\cdot {\\bf w}$. Below we visualize each of these inverse basis functions:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pylab.figure()\n", "pylab.title(\"Inverse Basis Functions\")\n", "pylab.plot(nengolib.networks.t_default, rw.inverse_basis().T)\n", "pylab.xlabel(\"Relative time (fraction of theta)\")\n", "pylab.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Since the basis functions for the balanced realization are nearly orthogonal, the inverse basis functions are approximately rescaled versions of the former." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Debugging issues in performance\n", "\n", "If the decoded function is not accurate, then first look at the state-space to see if it is being represented correctly. 
If not (you might see erratic oscillations or saturation at large values), then there are a few specific things to try:\n", "\n", "1. Pass a more representative training `process`:\n", "   - Make sure it corresponds to typical input stimuli\n", "   - Make it aperiodic over a longer time interval (at least 10 seconds)\n", "   - Make the process contain higher frequencies (to \"activate\" all of the dimensions), or decrease the `dimensions`, or increase `theta`\n", "   - Make the process contain lower frequencies (to put it within the range of the Padé approximants), or increase the `dimensions`, or decrease `theta`\n", "\n", "2. Pass `process=None`, and then:\n", "   - Set `encoders` to be axis-aligned (`nengo.dists.Choice(np.vstack([I, -I]))`, where `I` is the identity matrix)\n", "   - Set `radii` to the absolute maximum values of each dimension (`np.max(np.abs(x), axis=0)`, after realization)\n", "   - Set `eval_points=nengolib.stats.cube` or to some representative points in state-space (after radii+realization)\n", "\n", "3. Change the solver and/or regularization:\n", "   - Pass `solver=nengo.solvers.LstsqL2(reg=1e-X)` with different X ranging between `1` and `4`\n", "   - Pass this to either `add_output` (to apply only to the decoded function), or to the constructor (to apply to both the recurrent function and the decoded function)\n", "\n", "Otherwise, if your function is not expressible as $\\sum_{i=1}^{d} f_i ( {\\bf v}_i \\cdot {\\bf w} )$ for the above ${\\bf v}_i$ and for low-order $f_i$, try:\n", "\n", "1. Pass a different state-space realization by providing a different `realizer` object to the `RollingWindow`. See `LinearNetwork` for details on realizers. This might rotate the state-space into the above form (analogous to how the nonlinear `Product` network is just a diagonal rotation of the linear `EnsembleArray`).\n", "\n", "2. Create a second ensemble with uniformly distributed encoders (or some other non-orthogonal distribution) and communicate the state variable to that ensemble. 
Then decode the desired function from that second ensemble using the above basis matrix to define the function with respect to `x`.\n", "\n", "3. For expert users, the `RollingWindow` is designed to be very \"hackable\", in that you can specify many of the `Ensemble` and `Connection` parameters needed to tweak performance, or customize how the `process` is used to solve for the `eval_points` and `encoders`, or even subclass what happens in `_make_core`. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## References\n", "\n", "[1] Aaron R. Voelker and Chris Eliasmith. Improving spiking dynamical networks: accurate delays, higher-order synapses, and time cells. Neural Computation, 30(3):569-609, 03 2018. https://www.mitpressjournals.org/doi/abs/10.1162/neco_a_01046, doi:10.1162/neco_a_01046. [[GitHub](https://github.com/arvoelke/delay2017)]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\n", "\n" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }