{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# full-FORCE and \"Classic FORCE\" learning with spikes\n", "\n", "This notebook demonstrates how to implement both full-FORCE [1] and \"Classic FORCE\" [2] networks in Nengo. Using Nengo makes it \"trivial\" to switch between neuron models (rate-based, spiking, adaptive, etc.), and to explore the effects of different learning rules and architectural assumptions.\n", "\n", "For this demonstration, we use recursive least-squares (RLS) learning with spiking `LIF` neurons in each of the two basic architectures (full-FORCE and classic FORCE) to learn a bandpass filter (a.k.a. a \"decaying oscillator\" triggered by unit impulses)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pylab inline\n", "import pylab\n", "try:\n", " import seaborn as sns # optional; prettier graphs\n", "except ImportError:\n", " pass\n", "\n", "import numpy as np\n", "import nengo\n", "import nengolib\n", "from nengolib import RLS, Network" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Task parameters\n", "pulse_interval = 1.0\n", "amplitude = 0.1\n", "freq = 3.0\n", "decay = 2.0\n", "dt = 0.002\n", "trials_train = 3\n", "trials_test = 2\n", "\n", "# Fixed model parameters\n", "n = 200\n", "seed = 0\n", "rng = np.random.RandomState(seed)\n", "ens_kwargs = dict( # neuron parameters\n", " n_neurons=n,\n", " dimensions=1,\n", " neuron_type=nengo.LIF(), # nengolib.neurons.Tanh()\n", " intercepts=[-1]*n, # intercepts are irrelevant for Tanh\n", " seed=seed,\n", ")\n", "\n", "# Hyper-parameters\n", "tau = 0.1 # lowpass time-constant (10ms in [1])\n", "tau_learn = 0.1 # filter for error / learning (needed for spiking)\n", "tau_probe = 0.05 # filter for readout (needed for spiking)\n", "learning_rate = 0.1 # 1 in [1]\n", "g = 1.5 / 400 # 1.5 in [1], scaled by firing rates\n", "g_in = tau / amplitude # scale the input encoders (usually 1)\n", 
"g_out = 1.0 # scale the recurrent encoders (usually 1)\n", "\n", "# Pre-computed constants\n", "T_train = trials_train * pulse_interval\n", "T_total = (trials_train + trials_test) * pulse_interval" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with Network(seed=seed) as model:\n", " # Input is a pulse every pulse_interval seconds\n", " U = np.zeros(int(pulse_interval / dt))\n", " U[0] = amplitude / dt\n", " u = nengo.Node(output=nengo.processes.PresentInput(U, dt))\n", " \n", " # Desired output is a decaying oscillator\n", " z = nengo.Node(size_in=1)\n", " nengo.Connection(u, z, synapse=nengolib.synapses.Bandpass(freq, decay))\n", " \n", "# Initial weights\n", "e_in = g_in * rng.uniform(-1, +1, (n, 1)) # fixed encoders for f_in (u_in)\n", "e_out = g_out * rng.uniform(-1, +1, (n, 1)) # fixed encoders for f_out (u)\n", "JD = rng.randn(n, n) * g / np.sqrt(n) # target-generating weights (variance g^2/n)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Classic FORCE\n", "\n", " - `xC` are the neurons\n", " - `sC` are the unfiltered currents into each neuron (`sC -> Lowpass(tau) -> xC`)\n", " - `zC` is the learned output estimate, decoded by the neurons, and re-encoded back into `sC` alongside some random feedback (`JD`)\n", " - `eC` is a gated error signal for RLS that turns off after `T_train` seconds. This error signal learns the feedback decoders by minimizing the difference between `z` (ideal output) and `zC` (actual output).\n", "\n", "The error signal driving RLS has an additional filter applied (`tau_learn`) to handle the case when this signal consists of spikes (not rates)." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with model:\n", " xC = nengo.Ensemble(**ens_kwargs)\n", " sC = nengo.Node(size_in=n) # pre filter\n", " eC = nengo.Node(size_in=1, output=lambda t, e: e if t < T_train else 0)\n", " zC = nengo.Node(size_in=1) # learned output\n", "\n", " nengo.Connection(u, sC, synapse=None, transform=e_in)\n", " nengo.Connection(sC, xC.neurons, synapse=tau)\n", " nengo.Connection(xC.neurons, sC, synapse=None, transform=JD) # chaos\n", " connC = nengo.Connection(\n", " xC.neurons, zC, synapse=None, transform=np.zeros((1, n)),\n", " learning_rule_type=RLS(learning_rate=learning_rate, pre_synapse=tau_learn))\n", " nengo.Connection(zC, sC, synapse=None, transform=e_out)\n", "\n", " nengo.Connection(zC, eC, synapse=None) # actual\n", " nengo.Connection(z, eC, synapse=None, transform=-1) # ideal\n", " nengo.Connection(eC, connC.learning_rule, synapse=tau_learn)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## full-FORCE\n", "\n", "\n", "\n", "