Fitting an SNN to do a Classification Task

In the example below, we use gradient descent to fit the parameters of a spiking neural network (SNN). Because Jaxley uses surrogate gradients for Fire channels by default, we can backpropagate through the entire network even though spike generation itself is non-differentiable.

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from jax import config
from tqdm import tqdm

import jaxley as jx
from jaxley.channels import Fire, Leak
from jaxley.synapses import SpikeSynapse

# Switch JAX to double precision first: every jnp.array created below (and the
# whole simulation) then uses float64 instead of the float32 default.
config.update("jax_enable_x64", True)

# Integration step size and total simulated duration.
dt = 1e-3
T = 0.20

# Used below both as the leak reversal potential and as the Fire reset value.
reset_voltage = 0.0
tau_mem = 10e-3

# Initial Leak-channel parameters, keyed by their Jaxley parameter names.
# NOTE(review): the leak conductance is initialised with `tau_mem`, a value
# named like a membrane time constant -- confirm this is intended.
params_init_Leak = dict(
    Leak_gLeak=jnp.array(tau_mem),
    Leak_eLeak=jnp.array(reset_voltage),
)

# Initial Fire-channel parameters (`vth` and `vreset`), keyed the same way.
params_init_Fire = dict(
    Fire_vth=jnp.array(1.0),
    Fire_vreset=jnp.array(reset_voltage),
)

Simulation and network hyperparameters

Set the RNG seed and the size of each layer, then build the network (input → hidden → output). Note that the output neurons are created without Fire channels: for classification we read out their membrane potentials rather than their spikes.

# Reproducibility and layer sizes.
seed = 1701
key = jax.random.key(seed)
batch_size = 256
n_input = 100
n_hidden = 4
n_output = 2

# Node indices of each layer: inputs come first, then hidden, then outputs.
hidden_end = n_input + n_hidden
output_end = hidden_end + n_output
input_range = range(n_input)
hidden_range = range(n_input, hidden_end)
output_range = range(hidden_end, output_end)

# One single-compartment neuron per unit across all three layers.
neurons = [jx.Compartment() for _ in range(output_end)]
net = jx.Network(neurons)

# Every neuron gets a Leak channel; only input and hidden neurons (indices
# below hidden_end) get a Fire channel, so output neurons never spike.
net.select(nodes="all").insert(Leak())
net.select(nodes=range(hidden_end)).insert(Fire())
for param, value in params_init_Leak.items():
    net.select(nodes="all").set(param, value)
for param, value in params_init_Fire.items():
    net.select(nodes=range(hidden_end)).set(param, value)
# With radius 1.0 this length makes 2 * pi * radius * length equal to 1e5;
# presumably chosen to normalise membrane area -- confirm intent.
net.select(nodes="all").set("length", 1.0 / (2 * jnp.pi * 1e-5))
net.select(nodes="all").set("radius", 1.0)  # 1.0 is also the default.
net.select(nodes="all").set("v", 0.0)

# Dense feed-forward wiring: input -> hidden, then hidden -> output.
jx.fully_connect(net.cell(input_range), net.cell(hidden_range), SpikeSynapse())
jx.fully_connect(net.cell(hidden_range), net.cell(output_range), SpikeSynapse())


def normal(subkey, mean, std, shape):
    """Draw Gaussian samples with the given ``mean`` and ``std``.

    Args:
        subkey: JAX PRNG key consumed by the draw.
        mean: location of the distribution.
        std: standard deviation (scale) of the distribution.
        shape: shape of the returned array, forwarded to ``jax.random.normal``.
    """
    standard = jax.random.normal(subkey, shape)
    return mean + std * standard


# Set random synaptic weights.
# The edge indices below assume fully_connect() created all input->hidden edges
# first (indices 0 .. l1_edges-1), followed by the hidden->output edges.
l1_edges = n_input * n_hidden
l2_edges = n_hidden * n_output

# Note the trailing commas: `(l1_edges)` would be a plain int, not a shape tuple.
key, subkey = jax.random.split(key)
gs_l1 = [{"SpikeSynapse_gS": normal(subkey, 0.0, 10.0, (l1_edges,))}]
key, subkey = jax.random.split(key)
gs_l2 = [{"SpikeSynapse_gS": normal(subkey, 0.0, 10.0, (l2_edges,))}]

for edge, g in enumerate(gs_l1[0]["SpikeSynapse_gS"]):
    net.select(edges=edge).set("SpikeSynapse_gS", g)

# Hidden->output edges follow the input->hidden ones, hence the offset.
for edge, g in enumerate(gs_l2[0]["SpikeSynapse_gS"], start=l1_edges):
    net.select(edges=edge).set("SpikeSynapse_gS", g)

# One independent trainable conductance per synapse.
net.select(edges="all").make_trainable("SpikeSynapse_gS")

parameters = net.get_parameters()
Number of newly added trainable parameters: 408. Total number of trainable parameters: 408

Creating input currents / dataset

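For this toy classification task, we sample Poisson spike trains for the inputs (approximated by an independent Bernoulli draw with probability freq * dt at every time step) and convolve them with a small two-tap kernel (i_min) to create continuous input currents. We then stack batch_size such examples into current_batches, which is used for training, and record the membrane voltage of the output neurons.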

# Two-tap kernel: each input spike becomes a current of amplitude 508 that
# lasts two time steps.
i_min = jnp.array([508.0] * 2)
freq = 5  # Hz
# Per-step spike probability: an independent Bernoulli draw per time step.
prob = freq * dt
# round(), not int(): float error in T / dt must not silently drop a step.
n_steps = round(T / dt)
key, subkey = jax.random.split(key)
mask = jax.random.uniform(subkey, (batch_size, n_input, n_steps))
x_data = jnp.where(mask < prob, 1.0, 0.0)


def _smooth_spike_train(spikes):
    """Convolve one 1-D spike train with the ``i_min`` kernel (``"same"`` mode)."""
    return jnp.convolve(i_min, spikes, "same")


# The nested vmap maps over (sample, input neuron). It computes exactly what a
# Python loop over both axes would, but in one vectorised call instead of
# batch_size * n_input separate dispatches.
current_batches = jax.vmap(jax.vmap(_smooth_spike_train))(x_data)

# Reset recordings (presumably so re-running this cell does not accumulate
# them), then record the output neurons' voltages.
net.delete_recordings()
net.cell(output_range).record("v")
Added 2 recordings. See `.recordings` for details.

Simulation helpers

simulate_network runs a single simulation and returns the voltage traces of the output neurons. simulate_network_batches vectorizes it over the batch dimension with jax.vmap and JIT-compiles the result with jax.jit, which makes training faster.

def simulate_network(params, current):
    """Simulate the network once and return its recordings.

    Returns whatever ``net`` currently records; in this script that is the
    voltage of the output neurons.

    Args:
        params: trainable parameters, as returned by ``net.get_parameters()``.
        current: input currents for the input layer, one row per input neuron
            (assumed shape ``(n_input, n_steps)`` -- confirm).
    """
    # Stimulate the input layer through the same index range used to wire it.
    data_stimuli = net.cell(input_range).data_stimulate(current, None)
    return jx.integrate(
        net,
        data_stimuli=data_stimuli,
        params=params,
        delta_t=dt,
        t_max=T,
    )


# Map over the leading (batch) axis of the currents while sharing the
# parameters, then JIT-compile the batched simulation.
simulate_network_batches = jax.jit(jax.vmap(simulate_network, in_axes=(None, 0)))

Loss and accuracy

We take the maximum membrane potential over time for each output neuron and apply a log-softmax across the output neurons to obtain class log-probabilities. The loss is the standard cross-entropy between these log-probabilities and the integer class labels (targets), which are drawn uniformly at random from {0, 1}. The get_accuracy helper computes the classification accuracy, predicting for each example the output neuron with the largest maximum voltage (argmax).

def cross_entropy(logprobs, targets):
    """Return the mean negative log-likelihood of the target classes.

    Args:
        logprobs: log-probabilities indexed as ``logprobs[sample, class]``.
        targets: integer class index for each sample.
    """
    # Gather each sample's log-probability of its own target class.
    picked = jnp.take_along_axis(logprobs, targets[:, None], axis=1)
    return -picked.mean()


def loss_fn(params, targets):
    """Cross-entropy loss over the full training batch.

    Each output neuron's score is its peak recorded voltage (axis 2 of the
    batched recordings is assumed to be time -- confirm). A log-softmax across
    output neurons turns these scores into class log-probabilities.
    """
    peak_v = simulate_network_batches(params, current_batches).max(axis=2)
    return cross_entropy(jax.nn.log_softmax(peak_v, axis=1), targets)


key, subkey = jax.random.split(key)
# Binary labels: 1 with probability 0.5, else 0, drawn independently of the
# inputs, so the network has to fit (memorise) these fixed random labels.
# Note the trailing comma: `(batch_size)` alone is an int, not a shape tuple.
targets = jnp.where(
    jax.random.uniform(subkey, (batch_size,), minval=0.0, maxval=1.0) < 0.5, 1, 0
)


def get_accuracy(params):
    """Return the fraction of training examples classified as ``targets``.

    The predicted class of an example is the output neuron with the largest
    peak recorded voltage.
    """
    peaks = jnp.max(simulate_network_batches(params, current_batches), axis=2)
    predicted = peaks.argmax(axis=1)
    return (predicted == targets).mean()

Optimization and training

We use optax.adam to optimize the synaptic weights (which we flagged as trainable earlier). Each epoch performs one gradient step on the full batch, and the gradient computation is JIT-compiled. A tqdm progress bar displays the current loss and accuracy.

# JIT-compiled full-batch loss and gradient w.r.t. the trainable parameters.
grad_fn = jax.jit(jax.value_and_grad(loss_fn, argnums=0))
optimizer = optax.adam(learning_rate=0.1)
opt_params = parameters
opt_state = optimizer.init(opt_params)

n_epochs = 500
# get_accuracy runs an extra full-batch simulation, which roughly doubles the
# cost of an epoch. Evaluate it only periodically and after the final step.
eval_every = 10
accuracy = float("nan")
with tqdm(total=n_epochs, desc="Training") as pbar:
    for epoch in range(n_epochs):
        # The loss is evaluated at the parameters *before* this step's update.
        loss, gradient = grad_fn(opt_params, targets)

        # Optimizer step.
        updates, opt_state = optimizer.update(gradient, opt_state)
        opt_params = optax.apply_updates(opt_params, updates)

        if (epoch + 1) % eval_every == 0 or epoch == n_epochs - 1:
            accuracy = get_accuracy(opt_params)
        pbar.set_postfix({"Loss": f"{loss:.4f}", "Accuracy": f"{accuracy:.3f}"})
        pbar.update()
Training: 100%|████████████████████████| 500/500 [12:19<00:00,  1.48s/it, Loss=0.2618, Accuracy=0.883]

Visualizing output voltages

Finally, we run a simulation with the learned parameters and plot the output membrane potentials over time. The output neurons do not spike in this example (we never inserted Fire channels into them), so we inspect their voltages directly: the predicted class is the output neuron whose voltage reaches the highest maximum.

# Simulate the first training example with the learned parameters.
v = simulate_network(opt_params, current_batches[0])

# Derive the time axis from the number of returned samples rather than a float
# arange over [0, T + 2 * dt): floating-point arange can yield one sample too
# many or too few, which would make plot() fail on mismatched lengths.
# Assumes v is indexed as v[recording, time], matching the original v.T plot.
time_vec = jnp.arange(v.shape[-1]) * dt
fig, ax = plt.subplots(1, 1, figsize=(5, 2))
for i, trace in enumerate(v):
    ax.plot(time_vec, trace, label=f"output neuron {i}")
ax.set_xlabel("time")
ax.set_ylabel("membrane potential")
ax.legend()
plt.show()
../_images/2a83ab9cf12574ce04e7032d1495a7a6a3ea29cf7c8fd127b52be43a529908c8.png