{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# full-FORCE and \"Classic FORCE\" learning with spikes\n",
    "\n",
    "This notebook demonstrates how to implement both full-FORCE [1] and \"Classic FORCE\" [2] networks in Nengo. This makes it \"trivial\" to switch between neuron models (rate-based, spiking, adaptive, etc.), and to explore the effects of different learning rules and architectural assumptions.\n",
    "\n",
    "For this demonstration, we use recursive least-squares (RLS) learning with spiking `LIF` neurons, and the two basic architectures (full-FORCE and classic FORCE), to learn a bandpass filter (a.k.a. a \"decaying oscillator\" triggered by unit impulses)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pylab inline\n",
    "import pylab\n",
    "try:\n",
    "    import seaborn as sns  # optional; prettier graphs\n",
    "except ImportError:\n",
    "    pass\n",
    "\n",
    "import numpy as np\n",
    "import nengo\n",
    "import nengolib\n",
    "from nengolib import RLS, Network"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Task parameters\n",
    "pulse_interval = 1.0\n",
    "amplitude = 0.1\n",
    "freq = 3.0\n",
    "decay = 2.0\n",
    "dt = 0.002\n",
    "trials_train = 3\n",
    "trials_test = 2\n",
    "\n",
    "# Fixed model parameters\n",
    "n = 200\n",
    "seed = 0\n",
    "rng = np.random.RandomState(seed)\n",
    "ens_kwargs = dict(  # neuron parameters\n",
    "    n_neurons=n,\n",
    "    dimensions=1,\n",
    "    neuron_type=nengo.LIF(),  # nengolib.neurons.Tanh()\n",
    "    intercepts=[-1]*n,  # intercepts are irrelevant for Tanh\n",
    "    seed=seed,\n",
    ")\n",
    "\n",
    "# Hyper-parameters\n",
    "tau = 0.1  # lowpass time-constant (10ms in [1])\n",
    "tau_learn = 0.1  # filter for error / learning (needed for spiking)\n",
    "tau_probe = 0.05  # filter for readout (needed for spiking)\n",
    "learning_rate = 0.1  # 1 in [1]\n",
    "g = 1.5 / 400  # 1.5 in [1], scaled by firing rates\n",
    "g_in = tau / amplitude  # scale the input encoders (usually 1)\n",
    "g_out = 1.0  # scale the recurrent encoders (usually 1)\n",
    "\n",
    "# Pre-computed constants\n",
    "T_train = trials_train * pulse_interval\n",
    "T_total = (trials_train + trials_test) * pulse_interval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with Network(seed=seed) as model:\n",
    "    # Input is a pulse every pulse_interval seconds\n",
    "    U = np.zeros(int(pulse_interval / dt))\n",
    "    U[0] = amplitude / dt\n",
    "    u = nengo.Node(output=nengo.processes.PresentInput(U, dt))\n",
    "\n",
    "    # Desired output is a decaying oscillator\n",
    "    z = nengo.Node(size_in=1)\n",
    "    nengo.Connection(u, z, synapse=nengolib.synapses.Bandpass(freq, decay))\n",
    "\n",
    "# Initial weights\n",
    "e_in = g_in * rng.uniform(-1, +1, (n, 1))  # fixed encoders for f_in (u_in)\n",
    "e_out = g_out * rng.uniform(-1, +1, (n, 1))  # fixed encoders for f_out (u)\n",
    "JD = rng.randn(n, n) * g / np.sqrt(n)  # target-generating weights (variance g^2/n)"
   ]
  },
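  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick, optional sanity check (an aside, not part of the original models), recall that an $n \\times n$ random matrix whose entries have variance $g^2/n$ has a spectral radius of approximately $g$, by the circular law. The cell below verifies this for `JD`; the rest of the notebook runs the same without it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional sanity check: the spectral radius of JD should be close to g,\n",
    "# where g = 1.5 / 400 is the usual chaotic gain of 1.5 from [1] divided by\n",
    "# a typical spiking firing-rate scale of a few hundred Hz.\n",
    "radius = np.max(np.abs(np.linalg.eigvals(JD)))\n",
    "print(\"Spectral radius of JD: %.5f (expected ~ g = %.5f)\" % (radius, g))"
   ]
  },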
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Classic FORCE\n",
    "\n",
    " - `xC` are the neurons\n",
    " - `sC` are the unfiltered currents into each neuron (`sC -> Lowpass(tau) -> xC`)\n",
    " - `zC` is the learned output estimate, decoded by the neurons, and re-encoded back into `sC` alongside some random feedback (`JD`)\n",
    " - `eC` is a gated error signal for RLS that turns off after `T_train` seconds. This error signal learns the feedback decoders by minimizing the difference between `z` (ideal output) and `zC` (actual output).\n",
    "\n",
    "The error signal driving RLS has an additional filter applied (`tau_learn`) to handle the case when this signal consists of spikes (not rates)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with model:\n",
    "    xC = nengo.Ensemble(**ens_kwargs)\n",
    "    sC = nengo.Node(size_in=n)  # pre filter\n",
    "    eC = nengo.Node(size_in=1, output=lambda t, e: e if t < T_train else 0)\n",
    "    zC = nengo.Node(size_in=1)  # learned output\n",
    "\n",
    "    nengo.Connection(u, sC, synapse=None, transform=e_in)\n",
    "    nengo.Connection(sC, xC.neurons, synapse=tau)\n",
    "    nengo.Connection(xC.neurons, sC, synapse=None, transform=JD)  # chaos\n",
    "    connC = nengo.Connection(\n",
    "        xC.neurons, zC, synapse=None, transform=np.zeros((1, n)),\n",
    "        learning_rule_type=RLS(learning_rate=learning_rate, pre_synapse=tau_learn))\n",
    "    nengo.Connection(zC, sC, synapse=None, transform=e_out)\n",
    "\n",
    "    nengo.Connection(zC, eC, synapse=None)  # actual\n",
    "    nengo.Connection(z, eC, synapse=None, transform=-1)  # ideal\n",
    "    nengo.Connection(eC, connC.learning_rule, synapse=tau_learn)"
   ]
  },
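  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For intuition, here is a minimal numpy sketch of the recursive least-squares update from [2] that the `RLS` rule applies online at each timestep. This is schematic only, not nengolib's actual implementation; in particular, `nengolib.RLS` filters the activities with `pre_synapse` before applying the update, and its `learning_rate` roughly corresponds to the inverse of the regularization used to initialize the inverse correlation matrix."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rls_step(w, P, r, error):\n",
    "    \"\"\"One schematic RLS update, following [2].\n",
    "\n",
    "    w     -- current readout weights, shape (n,)\n",
    "    P     -- running estimate of the inverse correlation matrix, shape (n, n)\n",
    "    r     -- (filtered) neural activities at this timestep, shape (n,)\n",
    "    error -- scalar error, actual minus ideal output\n",
    "    \"\"\"\n",
    "    Pr = P.dot(r)\n",
    "    k = Pr / (1.0 + r.dot(Pr))  # gain vector\n",
    "    P -= np.outer(k, Pr)  # rank-1 update of the inverse correlation matrix\n",
    "    w -= error * k  # move the readout toward the target\n",
    "    return w, P\n",
    "\n",
    "# Tiny demonstration on a random (noiseless) linear regression problem:\n",
    "# the residual should shrink to nearly zero after enough samples.\n",
    "rng_demo = np.random.RandomState(1)\n",
    "A_demo = rng_demo.randn(1000, 10)\n",
    "y_demo = A_demo.dot(rng_demo.randn(10))\n",
    "w_demo, P_demo = np.zeros(10), np.eye(10)\n",
    "for r_t, y_t in zip(A_demo, y_demo):\n",
    "    w_demo, P_demo = rls_step(w_demo, P_demo, r_t, w_demo.dot(r_t) - y_t)\n",
    "print(\"RLS residual:\", np.linalg.norm(A_demo.dot(w_demo) - y_demo))"
   ]
  },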
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## full-FORCE\n",
    "\n",
    "![full-FORCE Figure 1](http://journals.plos.org/plosone/article/figure/image?size=large&id=info:doi/10.1371/journal.pone.0191527.g001)\n",
    "\n",
    "Figure 1. Network architecture from [1].\n",
    "\n",
    "### Target-Generating Network\n",
    "\n",
    "See Fig 1b.\n",
    "\n",
    " - `xD` are the neurons that behave like classic FORCE in the ideal case (assuming the ideal output `z` is perfectly re-encoded)\n",
    " - `sD` are the unfiltered currents into each neuron (`sD -> Lowpass(tau) -> xD`)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with model:\n",
    "    xD = nengo.Ensemble(**ens_kwargs)\n",
    "    sD = nengo.Node(size_in=n)  # pre filter\n",
    "\n",
    "    nengo.Connection(u, sD, synapse=None, transform=e_in)\n",
    "    nengo.Connection(z, sD, synapse=None, transform=e_out)\n",
    "    nengo.Connection(sD, xD.neurons, synapse=tau)\n",
    "    nengo.Connection(xD.neurons, sD, synapse=None, transform=JD)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Task-Performing Network\n",
    "\n",
    "See Fig 1a.\n",
    "\n",
    " - `xF` are the neurons\n",
    " - `sF` are the unfiltered currents into each neuron (`sF -> Lowpass(tau) -> xF`)\n",
    " - `eF` is a gated error signal for RLS that turns off after `T_train` seconds. This error signal learns the full-rank feedback weights by minimizing the difference between the unfiltered currents `sD` and `sF`.\n",
    "\n",
    "The error signal driving RLS also has the same filter applied (`tau_learn`) to handle spikes. The output estimate is trained offline on the entire training set using batched least-squares, since this gives the best performance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with model:\n",
    "    xF = nengo.Ensemble(**ens_kwargs)\n",
    "    sF = nengo.Node(size_in=n)  # pre filter\n",
    "    eF = nengo.Node(size_in=n, output=lambda t, e: e if t < T_train else np.zeros_like(e))\n",
    "\n",
    "    nengo.Connection(u, sF, synapse=None, transform=e_in)\n",
    "    nengo.Connection(sF, xF.neurons, synapse=tau)\n",
    "    connF = nengo.Connection(\n",
    "        xF.neurons, sF, synapse=None, transform=np.zeros((n, n)),\n",
    "        learning_rule_type=RLS(learning_rate=learning_rate, pre_synapse=tau_learn))\n",
    "\n",
    "    nengo.Connection(sF, eF, synapse=None)  # actual\n",
    "    nengo.Connection(sD, eF, synapse=None, transform=-1)  # ideal\n",
    "    nengo.Connection(eF, connF.learning_rule, synapse=tau_learn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "with model:\n",
    "    # Probes\n",
    "    p_z = nengo.Probe(z, synapse=tau_probe)\n",
    "    p_zC = nengo.Probe(zC, synapse=tau_probe)\n",
    "    p_xF = nengo.Probe(xF.neurons, synapse=tau_probe)\n",
    "\n",
    "with nengo.Simulator(model, dt=dt) as sim:\n",
    "    sim.run(T_total)"
   ]
  },
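  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A note on the offline solve used below: `nengo.solvers.LstsqL2` fits the readout by ridge regression from the probed activities to the target. The next cell is a minimal numpy sketch of that computation, under the assumption that a plain regularized least-squares solve captures the idea; Nengo's actual solver scales the regularization relative to the largest activity and uses more careful numerics."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def lstsq_l2_sketch(A, Y, reg=1e-2):\n",
    "    \"\"\"Schematic ridge regression: minimize ||A.dot(W) - Y||^2 + penalty.\n",
    "\n",
    "    A   -- filtered activities, shape (timesteps, n)\n",
    "    Y   -- targets, shape (timesteps, d)\n",
    "    reg -- regularization, taken relative to the largest activity\n",
    "    \"\"\"\n",
    "    m, n_neurons = A.shape\n",
    "    sigma = reg * np.max(np.abs(A))  # scale reg by the activity magnitude\n",
    "    G = A.T.dot(A) + m * sigma**2 * np.eye(n_neurons)\n",
    "    return np.linalg.solve(G, A.T.dot(Y))"
   ]
  },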
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We do the readout training for full-FORCE offline, since this gives better\n",
    "# performance without affecting anything else\n",
    "t_train = sim.trange() < T_train\n",
    "t_test = sim.trange() >= T_train\n",
    "\n",
    "solver = nengo.solvers.LstsqL2(reg=1e-2)\n",
    "wF, _ = solver(sim.data[p_xF][t_train], sim.data[p_z][t_train])\n",
    "zF = sim.data[p_xF].dot(wF)\n",
    "\n",
    "pylab.figure(figsize=(16, 6))\n",
    "pylab.title(\"Training Output\")\n",
    "pylab.plot(sim.trange()[t_train], sim.data[p_zC][t_train], label=\"classic-FORCE\")\n",
    "pylab.plot(sim.trange()[t_train], zF[t_train], label=\"full-FORCE\")\n",
    "pylab.plot(sim.trange()[t_train], sim.data[p_z][t_train], label=\"Ideal\", linestyle='--')\n",
    "pylab.xlabel(\"Time (s)\")\n",
    "pylab.ylabel(\"Output\")\n",
    "pylab.legend()\n",
    "pylab.show()\n",
    "\n",
    "pylab.figure(figsize=(16, 6))\n",
    "pylab.title(\"Testing Error\")\n",
    "pylab.plot(sim.trange()[t_test], sim.data[p_zC][t_test] - sim.data[p_z][t_test],\n",
    "           alpha=0.8, label=\"classic-FORCE\")\n",
    "pylab.plot(sim.trange()[t_test], zF[t_test] - sim.data[p_z][t_test],\n",
    "           alpha=0.8, label=\"full-FORCE\")\n",
    "pylab.xlabel(\"Time (s)\")\n",
    "pylab.ylabel(\"Error\")\n",
    "pylab.legend()\n",
    "pylab.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## References\n",
    "\n",
    "[1] DePasquale, B., Cueva, C. J., Rajan, K., & Abbott, L. F. (2018). full-FORCE: A target-based method for training recurrent networks. PLoS ONE, 13(2), e0191527. http://journals.plos.org/plosone/article?id=10.1371/journal.pone.0191527\n",
    "\n",
    "[2] Sussillo, D., & Abbott, L. F. (2009). Generating coherent patterns of activity from chaotic neural networks. Neuron, 63(4), 544-557. https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2756108/"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}