nengolib.learning.RLS

class nengolib.learning.RLS(learning_rate=1.0, pre_synapse=Lowpass(0.005))

Recursive least-squares rule for online decoder optimization.

This may be used to learn the weights on a nengo.Connection online, in an L2-optimal manner. It applies to the same scenarios as nengo.PES: minimizing some error signal over time.

In the end, the only real difference between RLS learning and using the nengo.solvers.LstsqL2 solver is when the learning takes place. With RLS, the weights are learned online from an error signal over time; with LstsqL2, the weights are learned offline in a batch optimization from the provided training data (eval_points and function).

The cost of RLS is \(\mathcal{O}\left(n^2\right)\) extra time and memory, where \(n\) is the number of pre-synaptic neurons. It is typically much more efficient to do the learning offline using the nengo.solvers.LstsqL2 solver.
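
For comparison, below is a minimal sketch of the offline route, in which the decoders are solved in batch when the simulator is built (the regularization value here is illustrative):

>>> import numpy as np
>>> import nengo
>>> with nengo.Network() as model:
...     u = nengo.Node(output=lambda t: np.sin(2*np.pi*t))
...     x = nengo.Ensemble(100, 1)
...     y = nengo.Node(size_in=1)
...     nengo.Connection(u, x, synapse=None)
...     # Decoders are optimized offline from eval_points and function,
...     # with L2 regularization controlled by reg.
...     conn = nengo.Connection(x, y, solver=nengo.solvers.LstsqL2(reg=0.1))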

Parameters:
learning_rate : float, optional

Effective learning rate. This is better understood as \(\frac{1}{\alpha}\), where \(\alpha\) is an L2-regularization term. A large learning rate means little regularization: learning is fast but prone to over-fitting. A small learning rate means heavy regularization: learning is slower but more robust. Defaults to 1.0. [1]

pre_synapse : nengo.synapses.Synapse, optional

Filter applied to the pre-synaptic neural activities for the purpose of computing the weight update. Defaults to a nengo.Lowpass filter with a time-constant of 5 ms; any other synapse may be substituted, as sketched below.
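
For instance, both parameters may be overridden at construction (the alpha filter and the values below are illustrative choices, not recommendations):

>>> import nengo
>>> from nengolib import RLS
>>> rls = RLS(learning_rate=0.1, pre_synapse=nengo.Alpha(0.01))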

Notes

RLS works by maintaining the inverse neural correlation matrix, \(\Gamma^{-1}\), where \(\Gamma = A^T A + \alpha I\) are the regularized correlations, \(A\) is a matrix of (possibly filtered) neural activities, and \(\alpha\) is an L2-regularization term controlled by the learning_rate. This matrix is used to project the error signal and update the weights to be L2-optimal at each time-step.
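
To make this concrete, below is a minimal NumPy sketch of one such update, written in the rank-1 (Sherman-Morrison) form used by [1]. The names rls_step, P, w, a, and error are illustrative, and this is not the simulator's internal implementation: P holds the running estimate of \(\Gamma^{-1}\) (initialized to learning_rate times the identity, i.e., \(\alpha^{-1} I\)), w holds the decoders being learned, a is the vector of filtered activities on the current step, and error is the decoded error.

>>> import numpy as np
>>> def rls_step(P, w, a, error):
...     Pa = P.dot(a)
...     k = Pa / (1.0 + a.dot(Pa))   # gain vector
...     P = P - np.outer(k, Pa)      # rank-1 update of the inverse correlations
...     w = w - np.outer(error, k)   # L2-optimal decoder step
...     return P, w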

The time-step does not play a role in this learning rule, apart from determining the time-scale over which the pre_synapse is discretized. A complete learning update is applied on every time-step.

Attributes that can be probed from this learning rule: pre_filtered, error, delta, inv_gamma.
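
For example, assuming conn is a nengo.Connection built with learning_rule_type=RLS(), these signals can be recorded inside the same model in the usual way:

>>> p_delta = nengo.Probe(conn.learning_rule, 'delta')
>>> p_gamma = nengo.Probe(conn.learning_rule, 'inv_gamma')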

References

[1] Sussillo, D., & Abbott, L. F. (2009). Generating coherent patterns of activity from chaotic neural networks. Neuron, 63(4), 544-557.

Examples

See full-FORCE and “Classic FORCE” learning with spikes for an example of how to use RLS to learn spiking FORCE [1] and “full-FORCE” networks in Nengo.

Below, we compare nengo.PES against RLS, learning a feed-forward communication channel (the identity function) online, starting from scratch (zero weights) with 100 spiking LIF neurons. A faster learning rate for nengo.PES over-fits to the most recent online examples, while a slower learning rate does not learn quickly enough; this is a general problem with greedy optimization. RLS performs better since it is L2-optimal.

>>> import numpy as np
>>> from nengolib import RLS, Network
>>> import nengo
>>> from nengo import PES
>>> tau = 0.005
>>> learning_rules = (PES(learning_rate=1e-3, pre_tau=tau),
...                   RLS(learning_rate=1e-5, pre_synapse=tau))
>>> with Network() as model:
...     u = nengo.Node(output=lambda t: np.sin(2*np.pi*t))
...     probes = []
...     for lr in learning_rules:
...         # Error node: relays y - u while training (t < 1), zero afterwards.
...         e = nengo.Node(size_in=1,
...                        output=lambda t, e: e if t < 1 else 0)
...         x = nengo.Ensemble(100, 1, seed=0)
...         y = nengo.Node(size_in=1)
...         nengo.Connection(u, e, synapse=None, transform=-1)
...         nengo.Connection(u, x, synapse=None)
...         # The learned connection starts from zero decoders.
...         conn = nengo.Connection(
...             x, y, synapse=None, learning_rule_type=lr,
...             function=lambda _: 0)
...         nengo.Connection(y, e, synapse=None)
...         nengo.Connection(e, conn.learning_rule, synapse=tau)
...         probes.append(nengo.Probe(y, synapse=tau))
...     probes.append(nengo.Probe(u, synapse=tau))
>>> with nengo.Simulator(model) as sim:
...     sim.run(2.0)
>>> import matplotlib.pyplot as plt
>>> plt.plot(sim.trange(), sim.data[probes[0]],
...          label=str(learning_rules[0]))
>>> plt.plot(sim.trange(), sim.data[probes[1]],
...          label=str(learning_rules[1]))
>>> plt.plot(sim.trange(), sim.data[probes[2]],
...          label="Ideal", linestyle='--')
>>> plt.vlines([1], -1, 1, label="Training -> Testing")
>>> plt.ylim(-2, 2)
>>> plt.legend(loc='upper right')
>>> plt.xlabel("Time (s)")
>>> plt.show()

[Output plot: decoded estimates from PES and RLS against the ideal sine wave, with the training-to-testing switch marked at t = 1 s.]
Attributes:
learning_rate : float

Effective learning rate; see the parameter of the same name above.

pre_synapse : nengo.synapses.Synapse

Filter applied to the pre-synaptic neural activities; see the parameter of the same name above.

size_in

Dimensionality of the error signal expected by this learning rule.