Defining groups

In this tutorial, you will learn how to:

  • define groups (aka sectionlists) to simplify interactions with Jaxley

Here is a code snippet which you will learn to understand in this tutorial:

from jax import jit, vmap
import jax.numpy as jnp
import jaxley as jx


net = ...  # See tutorial on Basics of Jaxley.

net.cell(0).add_to_group("fast_spiking")
net.cell(1).add_to_group("slow_spiking")

def simulate(params):
    param_state = None
    param_state = net.fast_spiking.data_set("HH_gNa", params[0], param_state)
    param_state = net.slow_spiking.data_set("HH_gNa", params[1], param_state)
    return jx.integrate(net, param_state=param_state)

# Define sodium conductances for fast- and slow-spiking neurons.
params = jnp.asarray([1.0, 0.1])

# Run simulation.
voltages = simulate(params)

In many cases, you might want to group several compartments (or branches, or cells) and assign a unique parameter or mechanism to this group. For example, you might want to define a couple of branches as basal and then assign a Hodgkin-Huxley mechanism only to those branches. Or you might define a couple of cells as fast spiking and assign them a high value for the sodium conductance. This tutorial shows how to do both.

from jax import config
config.update("jax_enable_x64", True)
config.update("jax_platform_name", "cpu")

import time
import matplotlib.pyplot as plt
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit, value_and_grad

import jaxley as jx
from jaxley.channels import Na, K, Leak
from jaxley.synapses import IonotropicSynapse
from jaxley.connect import fully_connect

First, we define a network as you saw in the previous tutorial:

comp = jx.Compartment()
branch = jx.Branch(comp, ncomp=2)
cell = jx.Cell(branch, parents=[-1, 0, 0, 1])
network = jx.Network([cell for _ in range(3)])

pre = network.cell([0, 1])
post = network.cell([2])
fully_connect(pre, post, IonotropicSynapse())

network.insert(Na())
network.insert(K())
network.insert(Leak())

Group: apical dendrites

Assume that, in each of the three neurons in this network, the second and fourth branches are apical dendrites. We can define this as:

for cell_ind in range(3):
    network.cell(cell_ind).branch(1).add_to_group("apical")
    network.cell(cell_ind).branch(3).add_to_group("apical")

After this, we can access network.apical just as we previously accessed any other part of the network:

network.apical.set("radius", 0.3)
network.apical.nodes
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v Na Na_gNa ... eK K_n Leak Leak_gLeak Leak_eLeak apical global_cell_index global_branch_index global_comp_index controlled_by_param
2 0 0 0 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... -90.0 0.2 True 0.0001 -70.0 True 0 1 2 0
3 0 0 1 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... -90.0 0.2 True 0.0001 -70.0 True 0 1 3 0
6 0 1 0 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... -90.0 0.2 True 0.0001 -70.0 True 0 3 6 0
7 0 1 1 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... -90.0 0.2 True 0.0001 -70.0 True 0 3 7 0
10 1 0 0 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... -90.0 0.2 True 0.0001 -70.0 True 1 5 10 0
11 1 0 1 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... -90.0 0.2 True 0.0001 -70.0 True 1 5 11 0
14 1 1 0 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... -90.0 0.2 True 0.0001 -70.0 True 1 7 14 0
15 1 1 1 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... -90.0 0.2 True 0.0001 -70.0 True 1 7 15 0
18 2 0 0 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... -90.0 0.2 True 0.0001 -70.0 True 2 9 18 0
19 2 0 1 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... -90.0 0.2 True 0.0001 -70.0 True 2 9 19 0
22 2 1 0 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... -90.0 0.2 True 0.0001 -70.0 True 2 11 22 0
23 2 1 1 10.0 0.3 5000.0 1.0 -70.0 True 0.05 ... -90.0 0.2 True 0.0001 -70.0 True 2 11 23 0

12 rows × 26 columns
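
To see where the row labels in this table come from, the following sketch recomputes them in plain Python, without Jaxley. It assumes that compartments are numbered consecutively, cell by cell and branch by branch, which matches the global_comp_index column above. The helper name global_comp_indices is hypothetical and not part of Jaxley.

```python
# Sketch only: reproduces the global compartment indices of the "apical"
# group by hand. Assumes consecutive numbering (cell -> branch -> compartment).
N_CELLS = 3       # network = jx.Network([cell for _ in range(3)])
N_BRANCHES = 4    # parents=[-1, 0, 0, 1]
N_COMP = 2        # ncomp=2


def global_comp_indices(cell_ind, branch_ind):
    """Hypothetical helper: global indices of all compartments in one branch."""
    first = (cell_ind * N_BRANCHES + branch_ind) * N_COMP
    return list(range(first, first + N_COMP))


apical = []
for cell_ind in range(N_CELLS):
    for branch_ind in (1, 3):  # second and fourth branch of every cell
        apical.extend(global_comp_indices(cell_ind, branch_ind))

print(apical)
```

The printed list contains twelve indices, one per row of the table above.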

Group: fast spiking

Similarly, you could define a group of fast-spiking cells. Assume that the first and second cells are fast-spiking:

network.cell(0).add_to_group("fast_spiking")
network.cell(1).add_to_group("fast_spiking")
network.fast_spiking.set("Na_gNa", 0.4)
network.fast_spiking.nodes
local_cell_index local_branch_index local_comp_index length radius axial_resistivity capacitance v Na Na_gNa ... K_n Leak Leak_gLeak Leak_eLeak apical fast_spiking global_cell_index global_branch_index global_comp_index controlled_by_param
0 0 0 0 10.0 1.0 5000.0 1.0 -70.0 True 0.4 ... 0.2 True 0.0001 -70.0 False True 0 0 0 0
1 0 0 1 10.0 1.0 5000.0 1.0 -70.0 True 0.4 ... 0.2 True 0.0001 -70.0 False True 0 0 1 0
2 0 1 0 10.0 0.3 5000.0 1.0 -70.0 True 0.4 ... 0.2 True 0.0001 -70.0 True True 0 1 2 0
3 0 1 1 10.0 0.3 5000.0 1.0 -70.0 True 0.4 ... 0.2 True 0.0001 -70.0 True True 0 1 3 0
4 0 2 0 10.0 1.0 5000.0 1.0 -70.0 True 0.4 ... 0.2 True 0.0001 -70.0 False True 0 2 4 0
5 0 2 1 10.0 1.0 5000.0 1.0 -70.0 True 0.4 ... 0.2 True 0.0001 -70.0 False True 0 2 5 0
6 0 3 0 10.0 0.3 5000.0 1.0 -70.0 True 0.4 ... 0.2 True 0.0001 -70.0 True True 0 3 6 0
7 0 3 1 10.0 0.3 5000.0 1.0 -70.0 True 0.4 ... 0.2 True 0.0001 -70.0 True True 0 3 7 0
8 1 0 0 10.0 1.0 5000.0 1.0 -70.0 True 0.4 ... 0.2 True 0.0001 -70.0 False True 1 4 8 0
9 1 0 1 10.0 1.0 5000.0 1.0 -70.0 True 0.4 ... 0.2 True 0.0001 -70.0 False True 1 4 9 0
10 1 1 0 10.0 0.3 5000.0 1.0 -70.0 True 0.4 ... 0.2 True 0.0001 -70.0 True True 1 5 10 0
11 1 1 1 10.0 0.3 5000.0 1.0 -70.0 True 0.4 ... 0.2 True 0.0001 -70.0 True True 1 5 11 0
12 1 2 0 10.0 1.0 5000.0 1.0 -70.0 True 0.4 ... 0.2 True 0.0001 -70.0 False True 1 6 12 0
13 1 2 1 10.0 1.0 5000.0 1.0 -70.0 True 0.4 ... 0.2 True 0.0001 -70.0 False True 1 6 13 0
14 1 3 0 10.0 0.3 5000.0 1.0 -70.0 True 0.4 ... 0.2 True 0.0001 -70.0 True True 1 7 14 0
15 1 3 1 10.0 0.3 5000.0 1.0 -70.0 True 0.4 ... 0.2 True 0.0001 -70.0 True True 1 7 15 0

16 rows × 27 columns
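
The apical column in this output is True for some rows and False for others, so a compartment can belong to several groups at once. As a rough mental model (not Jaxley's internal representation), each group can be thought of as a set of global compartment indices:

```python
# Sketch only: groups modeled as plain sets of global compartment indices.
# The index values are copied from the two tables above.
apical = {2, 3, 6, 7, 10, 11, 14, 15, 18, 19, 22, 23}
fast_spiking = set(range(16))  # all compartments of cells 0 and 1

# Compartments that are both apical and fast-spiking.
both = sorted(apical & fast_spiking)
# Apical compartments of the remaining (third) cell.
apical_only = sorted(apical - fast_spiking)

print(both)
print(apical_only)
```

Under this view, setting a parameter on one group changes only the compartments in that set and leaves their membership in other groups untouched.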

Groups from SWC files

If you are reading .swc morphologies, you can automatically assign groups with:

jx.read_swc(file_name, ncomp=n, assign_groups=True)  # assign_groups=True is the default

After that, you can directly use cell.soma, cell.apical, cell.basal, or cell.axon.
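
These group names correspond to the structure identifier stored in the second column of each SWC line. The sketch below illustrates that convention with a hand-written parser. It is not Jaxley's reader, the sample lines are made up for illustration, and the mapping from identifiers to the names soma, axon, basal, and apical is an assumption based on the standard SWC type codes 1 to 4. The helper group_swc_nodes is hypothetical.

```python
# Sketch only: group SWC nodes by their structure identifier (column 2).
from collections import defaultdict

# Assumed mapping from standard SWC type codes to the group names used above.
SWC_TYPE_TO_GROUP = {1: "soma", 2: "axon", 3: "basal", 4: "apical"}

# Made-up sample. Columns: id, type, x, y, z, radius, parent id.
SAMPLE_SWC = """\
# id type x y z radius parent
1 1 0.0 0.0 0.0 5.0 -1
2 3 0.0 -10.0 0.0 1.0 1
3 4 0.0 10.0 0.0 1.0 1
4 4 0.0 20.0 0.0 0.8 3
5 2 10.0 0.0 0.0 0.5 1
"""


def group_swc_nodes(text):
    """Hypothetical helper: map each group name to the node ids it contains."""
    groups = defaultdict(list)
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blank lines and comments
        fields = line.split()
        node_id, type_id = int(fields[0]), int(fields[1])
        name = SWC_TYPE_TO_GROUP.get(type_id)
        if name is not None:
            groups[name].append(node_id)
    return dict(groups)


swc_groups = group_swc_nodes(SAMPLE_SWC)
print(swc_groups)
```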

How groups are interpreted by .make_trainable()

If you make a parameter of a group trainable, then it will be treated as a single shared parameter for a given property:

network.fast_spiking.make_trainable("Na_gNa")
Number of newly added trainable parameters: 1. Total number of trainable parameters: 1

As such, get_parameters() returns only a single trainable parameter, which will be the sodium conductance for every compartment of every fast-spiking neuron:

network.get_parameters()
[{'Na_gNa': Array([0.4], dtype=float64)}]

If you instead want a separate parameter for every fast-spiking cell, do not use the group. Index the cells directly (remember that the fast-spiking neurons have indices [0, 1]). Here we do this for the axial resistivity:

network.cell([0,1]).make_trainable("axial_resistivity")
Number of newly added trainable parameters: 2. Total number of trainable parameters: 3
network.get_parameters()
[{'Na_gNa': Array([0.4], dtype=float64)},
 {'axial_resistivity': Array([5000., 5000.], dtype=float64)}]

This generated two parameters for the axial resistivity, one for each cell.
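
The difference between the two make_trainable() calls is how many free values end up controlling the compartments. The sketch below mimics this with plain Python lists. It is a conceptual illustration, not how Jaxley stores parameters internally. The helper expand is hypothetical, and the second resistivity value (4000.0) is made up so that the per-cell split is visible.

```python
# Sketch only: how one shared value vs. one value per cell would be
# spread over compartments. 2 cells x 4 branches x 2 compartments = 16.
COMPS_PER_CELL = 4 * 2
FAST_SPIKING_CELLS = [0, 1]


def expand(values, n_cells, comps_per_cell):
    """Hypothetical helper: repeat trainable values over all compartments.

    len(values) == 1       -> a single value shared by every compartment
    len(values) == n_cells -> one value per cell, shared within that cell
    """
    if len(values) == 1:
        return values * (n_cells * comps_per_cell)
    if len(values) == n_cells:
        return [v for v in values for _ in range(comps_per_cell)]
    raise ValueError("expected 1 value or one value per cell")


# network.fast_spiking.make_trainable("Na_gNa") -> 1 trainable value
shared = expand([0.4], len(FAST_SPIKING_CELLS), COMPS_PER_CELL)

# network.cell([0, 1]).make_trainable("axial_resistivity") -> 2 trainable values
# (4000.0 is a made-up value used only to tell the two cells apart)
per_cell = expand([5000.0, 4000.0], len(FAST_SPIKING_CELLS), COMPS_PER_CELL)

print(len(shared), set(shared))
print(per_cell[:8], per_cell[8:])
```

In the shared case, a gradient step changes the sodium conductance of all 16 compartments together. In the per-cell case, each cell's eight compartments move with their own value.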

Summary

Groups allow you to organize your simulation in a more intuitive way, and they let you share parameters across compartments with make_trainable().