full-FORCE and “Classic FORCE” learning with spikes

This notebook demonstrates how to implement both full-FORCE [1] and “Classic FORCE” [2] networks in Nengo. Building them in Nengo makes it straightforward to switch between neuron models (rate-based, spiking, adaptive, etc.) and to explore the effects of different learning rules and architectural assumptions.

For this demonstration, we use recursive least-squares (RLS) learning with spiking LIF neurons in the two basic architectures (full-FORCE and classic FORCE) to learn a bandpass filter, i.e., a “decaying oscillator” triggered by periodic impulses.
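
To make the task concrete, here is a minimal NumPy sketch of the input pulse train and a decaying-oscillator target. The helper names `pulse_train` and `decaying_oscillator` are hypothetical, and the exponential-envelope parameterization (`decay_rate`) is an assumption chosen for illustration; it is not necessarily how the bandpass synapse used later in this notebook interprets its arguments.

```python
import numpy as np

def pulse_train(pulse_interval=1.0, amplitude=0.1, dt=0.002, n_trials=2):
    """One impulse of area `amplitude` at the start of every interval."""
    steps = int(round(pulse_interval / dt))
    one_trial = np.zeros(steps)
    one_trial[0] = amplitude / dt  # height * dt == amplitude
    return np.tile(one_trial, n_trials)

def decaying_oscillator(t, freq=3.0, decay_rate=2.0):
    """Assumed impulse response: a sinusoid under an exponential envelope."""
    return np.exp(-decay_rate * t) * np.sin(2 * np.pi * freq * t)

dt = 0.002
u = pulse_train(dt=dt)            # two trials of 500 steps each
t = np.arange(500) * dt           # one trial's worth of time points
h = decaying_oscillator(t)        # impulse response (kernel)

# Target = input convolved with the kernel (discrete approximation)
z = dt * np.convolve(u, h)[:len(u)]
```

Each impulse restarts the same decaying oscillation, scaled by the impulse area, so every trial presents the network with an identical target.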

[1]:
%pylab inline
import pylab
try:
    import seaborn as sns  # optional; prettier graphs
except ImportError:
    pass

import numpy as np
import nengo
import nengolib
from nengolib import RLS, Network
Populating the interactive namespace from numpy and matplotlib
[2]:
# Task parameters
pulse_interval = 1.0
amplitude = 0.1
freq = 3.0
decay = 2.0
dt = 0.002
trials_train = 3
trials_test = 2

# Fixed model parameters
n = 200
seed = 0
rng = np.random.RandomState(seed)
ens_kwargs = dict(  # neuron parameters
    n_neurons=n,
    dimensions=1,
    neuron_type=nengo.LIF(),  # nengolib.neurons.Tanh()
    intercepts=[-1]*n,  # intercepts are irrelevant for Tanh
    seed=seed,
)

# Hyper-parameters
tau = 0.1                   # lowpass time-constant (10ms in [1])
tau_learn = 0.1             # filter for error / learning (needed for spiking)
tau_probe = 0.05            # filter for readout (needed for spiking)
learning_rate = 0.1         # 1 in [1]
g = 1.5 / 400               # 1.5 in [1], scaled by firing rates
g_in = tau / amplitude      # scale the input encoders (usually 1)
g_out = 1.0                 # scale the output-feedback encoders (usually 1)

# Pre-computed constants
T_train = trials_train * pulse_interval
T_total = (trials_train + trials_test) * pulse_interval
[3]:
with Network(seed=seed) as model:
    # Input is a pulse every pulse_interval seconds
    U = np.zeros(int(pulse_interval / dt))
    U[0] = amplitude / dt
    u = nengo.Node(output=nengo.processes.PresentInput(U, dt))

    # Desired output is a decaying oscillator
    z = nengo.Node(size_in=1)
    nengo.Connection(u, z, synapse=nengolib.synapses.Bandpass(freq, decay))

# Initial weights
e_in = g_in * rng.uniform(-1, +1, (n, 1))  # fixed encoders for f_in (u_in)
e_out = g_out * rng.uniform(-1, +1, (n, 1))  # fixed encoders for f_out (u)
JD = rng.randn(n, n) * g / np.sqrt(n)  # target-generating weights (variance g^2/n)

Classic FORCE

  • xC are the neurons
  • sC are the unfiltered currents into each neuron (sC -> Lowpass(tau) -> xC)
  • zC is the learned output estimate, decoded by the neurons, and re-encoded back into sC alongside some random feedback (JD)
  • eC is a gated error signal for RLS that turns off after T_train seconds. This error signal trains the feedback decoders by minimizing the difference between z (ideal output) and zC (actual output).

The error signal driving RLS has an additional filter applied (tau_learn) to handle the case when this signal consists of spikes (not rates).
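
For readers unfamiliar with RLS, the following is a minimal NumPy sketch of a textbook recursive least-squares update for a single linear readout. It is not the implementation behind the `RLS` learning rule used in this notebook; the function name `rls_step` and the initial scale of `P` are assumptions for illustration. The error follows the same sign convention as eC (actual minus ideal).

```python
import numpy as np

def rls_step(w, P, r, target):
    """One textbook RLS update for the readout z = w @ r."""
    Pr = P @ r
    k = Pr / (1.0 + r @ Pr)      # gain vector
    err = w @ r - target         # actual minus ideal, before the update
    w = w - err * k              # move the readout against the error
    P = P - np.outer(k, Pr)      # update the inverse-correlation estimate
    return w, P, err

rng = np.random.RandomState(0)
n = 20
w_true = rng.randn(n)            # hypothetical weights to recover
w = np.zeros(n)                  # start from zero, like connC's transform
P = np.eye(n) / 1e-3             # large P: weak initial regularization

for _ in range(200):
    r = rng.randn(n)             # stand-in for filtered neural activity
    w, P, err = rls_step(w, P, r, w_true @ r)
```

In this noiseless toy problem the weights converge to `w_true` within a few hundred samples. With spiking activity the inputs are much noisier, which is one reason the activity and error are filtered before reaching the learning rule.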

[4]:
with model:
    xC = nengo.Ensemble(**ens_kwargs)
    sC = nengo.Node(size_in=n)  # pre filter
    eC = nengo.Node(size_in=1, output=lambda t, e: e if t < T_train else 0)
    zC = nengo.Node(size_in=1)  # learned output

    nengo.Connection(u, sC, synapse=None, transform=e_in)
    nengo.Connection(sC, xC.neurons, synapse=tau)
    nengo.Connection(xC.neurons, sC, synapse=None, transform=JD)  # chaos
    connC = nengo.Connection(
        xC.neurons, zC, synapse=None, transform=np.zeros((1, n)),
        learning_rule_type=RLS(learning_rate=learning_rate, pre_synapse=tau_learn))
    nengo.Connection(zC, sC, synapse=None, transform=e_out)

    nengo.Connection(zC, eC, synapse=None)  # actual
    nengo.Connection(z, eC, synapse=None, transform=-1)  # ideal
    nengo.Connection(eC, connC.learning_rule, synapse=tau_learn)

full-FORCE


Figure 1. Network architecture from [1].

Target-Generating Network

See Fig 1b.

  • xD are the neurons that behave like classic FORCE in the ideal case (assuming the ideal output z is perfectly re-encoded)
  • sD are the unfiltered currents into each neuron (sD -> Lowpass(tau) -> xD)
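
The signal flow described in these bullets can be sketched without Nengo. The code below is a rate-based stand-in that assumes tanh units instead of spiking LIF neurons and a simple exponential discretization of the lowpass filter; `simulate_target_network` is a hypothetical name, and the weights and signals in the usage lines are illustrative.

```python
import numpy as np

def simulate_target_network(u, z, e_in, e_out, JD, tau=0.1, dt=0.002):
    """Return unfiltered currents S and filtered states X, each (T, n)."""
    n = JD.shape[0]
    a = np.exp(-dt / tau)        # per-step decay of the lowpass filter
    x = np.zeros(n)
    S = np.zeros((len(u), n))
    X = np.zeros((len(u), n))
    for k in range(len(u)):
        # input + re-encoded ideal output + random recurrent feedback
        s = e_in[:, 0] * u[k] + e_out[:, 0] * z[k] + JD @ np.tanh(x)
        x = a * x + (1 - a) * s  # s -> Lowpass(tau) -> x
        S[k], X[k] = s, x
    return S, X

rng = np.random.RandomState(0)
n, T = 50, 500
e_in = rng.uniform(-1, 1, (n, 1))
e_out = rng.uniform(-1, 1, (n, 1))
JD = rng.randn(n, n) * 1.5 / np.sqrt(n)
t = np.arange(T) * 0.002
u = np.zeros(T)
u[0] = 1.0                                           # single input pulse
z = np.exp(-2.0 * t) * np.sin(2 * np.pi * 3.0 * t)   # illustrative target
S, X = simulate_target_network(u, z, e_in, e_out, JD)
```

The currents `S` play the role of sD: they are the signals that the task-performing network is later trained to reproduce from its own activity.
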
[5]:
with model:
    xD = nengo.Ensemble(**ens_kwargs)
    sD = nengo.Node(size_in=n)  # pre filter

    nengo.Connection(u, sD, synapse=None, transform=e_in)
    nengo.Connection(z, sD, synapse=None, transform=e_out)
    nengo.Connection(sD, xD.neurons, synapse=tau)
    nengo.Connection(xD.neurons, sD, synapse=None, transform=JD)

Task-Performing Network

See Fig 1a.

  • xF are the neurons
  • sF are the unfiltered currents into each neuron (sF -> Lowpass(tau) -> xF)
  • eF is a gated error signal for RLS that turns off after T_train seconds. This error signal learns the full-rank feedback weights by minimizing the difference between the unfiltered currents sD and sF.

The error signal driving RLS also has the same filter applied (tau_learn) to handle spikes. The output estimate is trained offline from the entire training set using batched least-squares, since this gives the best performance.
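
The current-matching objective described above can also be illustrated offline. The sketch below uses batch ridge regression as a stand-in for the online RLS rule used in this notebook; `fit_recurrent_weights` and `reg` are hypothetical names, and the data are synthetic.

```python
import numpy as np

def fit_recurrent_weights(X, S_target, reg=1e-6):
    """Ridge regression for J such that X @ J.T approximates S_target.

    X        : (T, n) activities of the task-performing network
    S_target : (T, n) target currents to reproduce
    """
    n = X.shape[1]
    G = X.T @ X + reg * np.eye(n)              # regularized Gram matrix
    return np.linalg.solve(G, X.T @ S_target).T

rng = np.random.RandomState(0)
T, n = 500, 30
X = np.tanh(rng.randn(T, n))                   # synthetic activities
J_true = rng.randn(n, n) / np.sqrt(n)          # hypothetical weights
S_target = X @ J_true.T                        # currents they generate
J = fit_recurrent_weights(X, S_target)
```

Every row of `J` is an independent regression problem that shares the same Gram matrix. The same routine with a single target column would give a linear readout, although the notebook uses `nengo.solvers.LstsqL2` for that, and its regularization scaling may differ from this sketch.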

[6]:
with model:
    xF = nengo.Ensemble(**ens_kwargs)
    sF = nengo.Node(size_in=n)  # pre filter
    eF = nengo.Node(size_in=n, output=lambda t, e: e if t < T_train else np.zeros_like(e))

    nengo.Connection(u, sF, synapse=None, transform=e_in)
    nengo.Connection(sF, xF.neurons, synapse=tau)
    connF = nengo.Connection(
        xF.neurons, sF, synapse=None, transform=np.zeros((n, n)),
        learning_rule_type=RLS(learning_rate=learning_rate, pre_synapse=tau_learn))

    nengo.Connection(sF, eF, synapse=None)  # actual
    nengo.Connection(sD, eF, synapse=None, transform=-1)  # ideal
    nengo.Connection(eF, connF.learning_rule, synapse=tau_learn)

Results

[7]:
with model:
    # Probes
    p_z = nengo.Probe(z, synapse=tau_probe)
    p_zC = nengo.Probe(zC, synapse=tau_probe)
    p_xF = nengo.Probe(xF.neurons, synapse=tau_probe)

with nengo.Simulator(model, dt=dt) as sim:
    sim.run(T_total)
/home/arvoelke/anaconda3/envs/py36/lib/python3.6/site-packages/nengo/utils/numpy.py:79: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.
  v = a[inds]
/home/arvoelke/CTN/nengolib/nengolib/signal/system.py:198: UserWarning: y0 (None!=0) does not properly initialize the system; see Nengo issue #1124.
  "Nengo issue #1124." % y0, UserWarning)
[8]:
# We do the readout training for full-FORCE offline, since this gives better
# performance without affecting anything else
t_train = sim.trange() < T_train
t_test = sim.trange() >= T_train

solver = nengo.solvers.LstsqL2(reg=1e-2)
wF, _ = solver(sim.data[p_xF][t_train], sim.data[p_z][t_train])
zF = sim.data[p_xF].dot(wF)

pylab.figure(figsize=(16, 6))
pylab.title("Training Output")
pylab.plot(sim.trange()[t_train], sim.data[p_zC][t_train], label="classic-FORCE")
pylab.plot(sim.trange()[t_train], zF[t_train], label="full-FORCE")
pylab.plot(sim.trange()[t_train], sim.data[p_z][t_train], label="Ideal", linestyle='--')
pylab.xlabel("Time (s)")
pylab.ylabel("Output")
pylab.legend()
pylab.show()

pylab.figure(figsize=(16, 6))
pylab.title("Testing Error")
pylab.plot(sim.trange()[t_test], sim.data[p_zC][t_test] - sim.data[p_z][t_test],
           alpha=0.8, label="classic-FORCE")
pylab.plot(sim.trange()[t_test], zF[t_test] - sim.data[p_z][t_test],
           alpha=0.8, label="full-FORCE")
pylab.xlabel("Time (s)")
pylab.ylabel("Error")
pylab.legend()
pylab.show()
[Figure: Training Output. Classic-FORCE, full-FORCE, and ideal outputs versus time (s).]
[Figure: Testing Error. Classic-FORCE and full-FORCE output errors versus time (s).]

References

[1] DePasquale, B., Cueva, C. J., Rajan, K., & Abbott, L. F. (2018). full-FORCE: A target-based method for training recurrent networks. PloS one, 13(2), e0191527. http://journals.plos.org/plosone/article?id=10.1371/journal.pone.0191527

[2] Sussillo, D., & Abbott, L. F. (2009). Generating coherent patterns of activity from chaotic neural networks. Neuron, 63(4), 544-557. https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2756108/
