Modules

class Module

Module base class.

A Module is anything that can be passed to jx.integrate, i.e. compartments, branches, cells, and networks.

This base class defines the scaffold for all jaxley modules (compartments, branches, cells, networks).

Modules can be traversed and modified using the at, cell, branch, comp, edge, and loc methods. The scope method can be used to toggle between global and local indices. Traversing a Module returns a View of it whose attributes only consider the part of the Module that is in view.

For developers: This has consequences for how to operate on a Module and for where changes take effect. The following guidelines should be followed (copied from View):

  1. We consider a Module to have everything in view.

  2. Views can display and keep track of how a module is traversed, but they do not support making changes or setting variables. Such changes still have to be made in the base Module, i.e. self.base. To ensure that these changes only affect whatever is currently in view, self._nodes_in_view, self._edges_in_view, or similar attributes have to be used. For example, the nodes currently in view can be operated on with self.base.nodes.loc[self._nodes_in_view].

  3. Every attribute of Module that changes based on what's in view, e.g. xyzr, needs to be modified when a View is instantiated. For example, xyzr of cell.branch(0) should be [self.base.xyzr[0]]. This can be achieved via [self.base.xyzr[b] for b in self._branches_in_view].
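The three guidelines can be illustrated with a minimal, self-contained sketch. ToyModule, ToyView, and set_radius are hypothetical stand-ins, not jaxley classes: the View stores only indices into its base, view-dependent attributes are recomputed from self.base, and writes go to self.base at the indices in view, so one function works on both.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class ToyModule:
    """Hypothetical stand-in for a Module: it owns the data."""

    radius: List[float] = field(default_factory=list)

    @property
    def base(self) -> "ToyModule":
        # Guideline 1: a Module has everything in view and is its own base.
        return self

    @property
    def _nodes_in_view(self) -> List[int]:
        return list(range(len(self.radius)))

    def comp(self, idx: List[int]) -> "ToyView":
        # Traversal returns a View that only records which nodes are in view.
        return ToyView(self, list(idx))


@dataclass
class ToyView:
    """Hypothetical stand-in for a View: it stores indices into its base."""

    base: ToyModule
    _nodes_in_view: List[int]

    @property
    def radius(self) -> List[float]:
        # Guideline 3: view-dependent attributes are derived from self.base.
        return [self.base.radius[i] for i in self._nodes_in_view]


def set_radius(obj, val: float) -> None:
    # Guideline 2: write to obj.base, restricted to obj._nodes_in_view.
    # The same code works for ToyModule and ToyView.
    for i in obj._nodes_in_view:
        obj.base.radius[i] = val


module = ToyModule(radius=[1.0, 1.0, 1.0, 1.0])
set_radius(module.comp([1, 3]), 0.5)
print(module.radius)  # [1.0, 0.5, 1.0, 0.5]
```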

For developers: If you want to add a new method to Module, here is an example of how to make methods of Module compatible with View:

```python
import numpy as np


# Use data in view to return something.
def count_small_branches(self):
    # No need to use self.base.attr + viewed indices, since no change
    # is made to the attr in question (nodes). self.nodes only contains
    # the compartments that are currently in view.
    branch_lens = self.nodes.groupby("global_branch_index")["length"].sum()
    # Number of branches in view that are shorter than 10 um.
    return np.sum(branch_lens < 10)


# Change data in view.
def change_attr_in_view(self):
    # Changes to attrs have to be made via self.base.attr + viewed indices.
    # attr1, attr2, attr3, func1 and func2 are placeholders.
    a = func1(self.base.attr1[self._cells_in_view])
    b = func2(self.base.attr2[self._edges_in_view])
    self.base.attr3[self._branches_in_view] = a + b
```

compute_compartment_centers()

Add compartment centers to the nodes DataFrame.

select(nodes=None, edges=None, sorted=False)

Return a View of the module filtered by specific node or edge indices.

The selection is made based on the index of the self.nodes or self.edges, i.e., not on a local compartment index or a local row number (loc, not iloc).

Parameters:
  • nodes (ndarray | None) – indices of nodes to view. If None, all nodes are viewed.

  • edges (ndarray | None) – indices of edges to view. If None, all edges are viewed.

  • sorted (bool) – if True, nodes and edges are sorted.

Returns:

View for subset of selected nodes and/or edges.

Return type:

View
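A minimal sketch of the loc-versus-iloc distinction, assuming the node table is modeled as a plain dict keyed by index label (a hypothetical stand-in for self.nodes; the sort flag mirrors the sorted argument). Selecting [12, 10] picks the rows labeled 12 and 10 even though the table only has three rows, whereas a positional lookup would fail for 12.

```python
from typing import Dict, List, Tuple

Row = Dict[str, float]


def select_by_label(
    nodes: Dict[int, Row], indices: List[int], sort: bool = False
) -> List[Tuple[int, Row]]:
    """Pick rows by index label (loc semantics), not by position (iloc)."""
    labels = sorted(indices) if sort else list(indices)
    return [(label, nodes[label]) for label in labels]


# After traversal, index labels need not start at 0 or be contiguous.
nodes = {10: {"radius": 1.0}, 11: {"radius": 2.0}, 12: {"radius": 3.0}}
print(select_by_label(nodes, [12, 10]))             # labels 12, 10 in given order
print(select_by_label(nodes, [12, 10], sort=True))  # labels 10, 12
```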

set_scope(scope)

Toggle between “global” and “local” scope.

Determines if global or local indices are used for viewing the module.

Parameters:

scope (str) – either “global” or “local”.

scope(scope)

Return a View of the module with the specified scope.

For example, cell.scope("global").branch(2).scope("local").comp(1) will return the compartment with local index 1 (i.e. the second compartment) of branch 2.

Parameters:

scope (str) – either “global” or “local”.

Returns:

View with the specified scope.

Return type:

View
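To make the global/local distinction concrete, here is a sketch that converts a local compartment index within a branch into a global one. It assumes that compartments are numbered contiguously, branch by branch, starting from 0; global_comp_index is a hypothetical helper, not part of jaxley. With branches of 4, 2, and 3 compartments, local compartment 1 of branch 2 has global index 7.

```python
from itertools import accumulate
from typing import List


def global_comp_index(ncomp_per_branch: List[int], branch: int, local_comp: int) -> int:
    # Offset of each branch = total number of compartments in all preceding branches.
    offsets = [0] + list(accumulate(ncomp_per_branch))[:-1]
    if not 0 <= local_comp < ncomp_per_branch[branch]:
        raise IndexError("local compartment index out of range for this branch")
    return offsets[branch] + local_comp


# Three branches with 4, 2 and 3 compartments.
print(global_comp_index([4, 2, 3], branch=2, local_comp=1))  # 7
```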

cell(idx)

Return a View of the module at the selected cell(s).

Parameters:

idx (Any) – index of the cell to view.

Returns:

View of the module at the specified cell index.

Return type:

View

branch(idx)

Return a View of the module at the selected branch(es).

Parameters:

idx (Any) – index of the branch to view.

Returns:

View of the module at the specified branch index.

Return type:

View

comp(idx)

Return a View of the module at the selected compartment(s).

Parameters:

idx (Any) – index of the compartment to view.

Returns:

View of the module at the specified compartment index.

Return type:

View

edge(idx)

Return a View of the module at the selected synapse edge(s).

Parameters:

idx (Any) – index of the edge to view.

Returns:

View of the module at the specified edge index.

Return type:

View

loc(at)

Return a View of the module at the selected branch location(s).

Parameters:

at (Any) – location along the branch.

Returns:

View of the module at the specified branch location.

Return type:

View
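A sketch of how a location along a branch could map to a compartment, assuming at is a relative position in [0, 1] and the branch is split into ncomp compartments of equal length. loc_to_comp is hypothetical, and jaxley's exact handling of locations that fall on compartment boundaries may differ.

```python
import math


def loc_to_comp(at: float, ncomp: int) -> int:
    """Map a relative location in [0, 1] to a compartment index (equal lengths assumed)."""
    if not 0.0 <= at <= 1.0:
        raise ValueError("at must lie in [0, 1]")
    # at == 1.0 would give ncomp, so clip to the last compartment.
    return min(math.floor(at * ncomp), ncomp - 1)


print([loc_to_comp(x, ncomp=4) for x in (0.0, 0.3, 0.5, 1.0)])  # [0, 1, 2, 3]
```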

property cells

Iterate over all cells in the module.

Returns a generator that yields a View of each cell.

property branches

Iterate over all branches in the module.

Returns a generator that yields a View of each branch.

property comps

Iterate over all compartments in the module. Unlike __iter__, this can be called on any module, i.e. net.comps, cell.comps, or branch.comps.

Returns a generator that yields a View of each compartment.

property shape: Tuple[int]

Returns the number of submodules contained in a module.

```
network.shape = (num_cells, num_branches, num_compartments)
cell.shape = (num_branches, num_compartments)
branch.shape = (num_compartments,)
```

copy(reset_index=False, as_module=False)

Extract part of a module and return a copy of its View or a new module.

This can be used to call jx.integrate on part of a Module.

Parameters:
  • reset_index (bool) – if True, the indices of the new module are reset to start from 0.

  • as_module (bool) – if True, a new module is returned instead of a View.

Returns:

A part of the module or a copied view of it.

Return type:

Module | View

property view

Return a View of the module.

show(param_names=None, *, indices=True, params=True, states=True, channel_names=None)

Print detailed information about the Module or a view of it.

Parameters:
  • param_names (str | List[str] | None) – The names of the parameters to show. If None, all parameters are shown.

  • indices (bool) – Whether to show the indices of the compartments.

  • params (bool) – Whether to show the parameters of the compartments.

  • states (bool) – Whether to show the states of the compartments.

  • channel_names (List[str] | None) – The names of the channels to show. If None, all channels are shown.

Returns:

A pd.DataFrame with the requested information.

Return type:

DataFrame

set(key, val)

Set a parameter of the module (or its view) to a new value.

Note that this function cannot be called within jax.jit or jax.grad. Instead, it should be used to set the parameters of the module before the simulation. Use .data_set() to set parameters within jax.jit or jax.grad.

Parameters:
  • key (str) – The name of the parameter to set.

  • val (float | Array) – The value to set the parameter to. If it is a jnp.ndarray, it must be of shape (num_compartments,).

data_set(key, val, param_state)

Set a parameter of the module (or its view) to a new value within jit.

Parameters:
  • key (str) – The name of the parameter to set.

  • val (float | Array) – The value to set the parameter to. If it is a jnp.ndarray, it must be of shape (num_compartments,).

  • param_state (List[Dict] | None) – State of the parameters that have already been set; used internally so that this function does not modify global state.

set_ncomp(ncomp, min_radius=None)

Set the number of compartments with which the branch is discretized.

Parameters:
  • ncomp (int) – The number of compartments that the branch should be discretized into.

  • min_radius (float | None) – Only used if the morphology was read from an SWC file. If passed, the radius is capped to be at least this value.

Raises:
  • When there are stimuli in any compartment in the module.

  • When there are recordings in any compartment in the module.

  • When the channels of the compartments are not the same within the branch that is modified.

  • When the lengths of the compartments are not the same within the branch that is modified.

  • When the branch that is modified has compartments belonging to different groups.

  • Unless the morphology was read from an SWC file, when the radii of the compartments are not the same within the branch that is modified.

make_trainable(key, init_val=None, verbose=True)

Make a parameter trainable.

If a parameter is made trainable, it will be returned by get_parameters() and should then be passed to jx.integrate(…, params=params).

Parameters:
  • key (str) – Name of the parameter to make trainable.

  • init_val (float | list | None) – Initial value of the parameter. If float, the same value is used for every created parameter. If list, the length of the list has to match the number of created parameters. If None, the current parameter value is used; if parameters are shared, the current values are averaged over all shared parameters.

  • verbose (bool) – Whether to print the number of parameters that are added and the total number of parameters.
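The init_val=None rule for shared parameters can be sketched as a group-wise mean. initial_shared_values is a hypothetical helper; it assumes each compartment is tagged with the id of the shared parameter it belongs to.

```python
from collections import defaultdict
from typing import Dict, List


def initial_shared_values(current: List[float], share_ids: List[int]) -> Dict[int, float]:
    """current[i] is the present value in compartment i; share_ids[i] is the
    (hypothetical) id of the trainable parameter that compartment i is tied to."""
    groups: Dict[int, List[float]] = defaultdict(list)
    for value, sid in zip(current, share_ids):
        groups[sid].append(value)
    # One trainable value per shared parameter: the mean over its members.
    return {sid: sum(vals) / len(vals) for sid, vals in groups.items()}


# Four compartments; compartments 0-1 share one parameter, 2-3 another.
print(initial_shared_values([0.1, 0.3, 0.2, 0.2], [0, 0, 1, 1]))
```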

write_trainables(trainable_params)

Write the trainables into .nodes and .edges.

This allows one to, e.g., visualize trained networks with .vis().

Parameters:

trainable_params (List[Dict[str, Array]]) – The trainable parameters returned by get_parameters().

distance(endpoint)

Return the direct distance between two compartments. This does not compute the pathwise distance (which is currently not implemented).

Parameters:

endpoint (View) – The compartment to which to compute the distance.

Return type:

float
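A sketch contrasting the direct distance with the pathwise distance that is not implemented, assuming the direct distance is the Euclidean distance between the two compartments' xyz coordinates in um. Both helpers are hypothetical.

```python
import math
from typing import Sequence

Point = Sequence[float]


def direct_distance(start: Point, end: Point) -> float:
    """Straight-line (Euclidean) distance between two xyz coordinates."""
    return math.dist(start, end)


def path_distance(points: Sequence[Point]) -> float:
    """Length of a traced path: sum of the straight segments between points."""
    return sum(math.dist(a, b) for a, b in zip(points, points[1:]))


# An L-shaped path from (0, 0, 0) via (3, 0, 0) to (3, 4, 0), in um.
path = [(0.0, 0.0, 0.0), (3.0, 0.0, 0.0), (3.0, 4.0, 0.0)]
print(direct_distance(path[0], path[-1]))  # 5.0
print(path_distance(path))                 # 7.0
```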

delete_trainables()

Removes all trainable parameters from the module.

add_to_group(group_name)

Add a view of the module to a group.

Groups can then be indexed. For example:

```python
net.cell(0).add_to_group("excitatory")
net.excitatory.set("radius", 0.1)
```

Parameters:

group_name (str) – The name of the group.

get_parameters()

Get all trainable parameters.

The returned parameters should be passed to jx.integrate(…, params=params).

Returns:

A list of all trainable parameters in the form of [{"gNa": jnp.array([0.1, 0.2, 0.3])}, …].

Return type:

List[Dict[str, Array]]

property initialized: bool

Whether the Module is ready to be solved or not.

delete_recordings()

Removes all recordings from the module.

stimulate(current=None, verbose=True)

Insert a stimulus into the compartments of the module (or its view).

current must be a 1d array or have a batch dimension of size (num_compartments,) or (1,). If it is 1d, the same stimulus is added to all compartments.

This function cannot be run within jax.jit or jax.grad. Because of this, it should only be used for static stimuli (i.e., stimuli that do not depend on the data and that should not be learned). For stimuli that depend on data (or that should be learned), please use data_stimulate().

Parameters:
  • current (Array | None) – Current in nA.

  • verbose (bool)

clamp(state_name, state_array, verbose=True)

Clamp a state to a given value across specified compartments.

This function sets external states for the compartments.

Parameters:
  • state_name (str) – The name of the state to clamp.

  • state_array (Array) – Array of values to clamp the state to.

  • verbose (bool) – If True, prints details about the clamping.

data_stimulate(current, data_stimuli=None, verbose=False)

Insert a stimulus into the module within jit (or grad).

Parameters:
  • current (Array) – Current in nA.

  • verbose (bool) – Whether or not to print the number of inserted stimuli. False by default because this method is meant to be jitted.

  • data_stimuli (Tuple[Array, DataFrame] | None)

Return type:

Tuple[Array, DataFrame]

data_clamp(state_name, state_array, data_clamps=None, verbose=False)

Insert a clamp into the module within jit (or grad).

Parameters:
  • state_name (str) – Name of the state variable to set.

  • state_array (Array) – Time series of the state variable in the default Jaxley unit. State array should be of shape (num_clamps, simulation_time) or (simulation_time, ) for a single clamp.

  • verbose (bool) – Whether or not to print the number of inserted clamps. False by default because this method is meant to be jitted.

  • data_clamps (Tuple[Array, DataFrame] | None)
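The two accepted shapes of state_array can be illustrated by normalizing both to (num_clamps, simulation_time). as_clamp_matrix is a hypothetical helper that uses nested lists in place of jnp arrays.

```python
from typing import List, Sequence, Union

Series = Sequence[float]


def as_clamp_matrix(state_array: Union[Series, Sequence[Series]]) -> List[List[float]]:
    """Bring a clamp array to shape (num_clamps, simulation_time)."""
    if len(state_array) == 0:
        raise ValueError("state_array must not be empty")
    if isinstance(state_array[0], (int, float)):
        # Shape (simulation_time,): a single clamp.
        return [list(state_array)]
    rows = [list(row) for row in state_array]
    if len({len(row) for row in rows}) != 1:
        raise ValueError("every clamp must have the same number of time steps")
    return rows


print(as_clamp_matrix([-70.0, -65.0, -60.0]))             # one clamp, three steps
print(as_clamp_matrix([[-70.0, -70.0], [-60.0, -60.0]]))  # two clamps, two steps
```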

delete_stimuli()

Removes all stimuli from the module.

delete_clamps(state_name=None)

Removes all clamps of the given state from the module.

Parameters:

state_name (str | None)

insert(channel)

Insert a channel or pump into the module.

Parameters:

channel (Channel | Pump) – The channel or pump to insert.

delete(channel)

Remove a channel or pump from the module.

Parameters:

channel (Channel | Pump) – The channel or pump to remove.

vis(ax=None, color='k', dims=(0, 1), type='line', **kwargs)

Visualize the module.

Modules can be visualized on one of the cardinal planes (xy, xz, yz) or even in 3D.

Several options are available:
  • line: All points from the traced morphology (xyzr) are connected with a line plot.

  • scatter: All traced points are plotted as scatter points.

  • comp: Plots the compartmentalized morphology, including radius and shape. By default, this shows the true compartment lengths; this can be changed via the kwargs (for details, see jaxley.utils.plot_utils.plot_comps).

  • morph: Reconstructs the 3D shape of the traced morphology. For details, see jaxley.utils.plot_utils.plot_morph.

Warning: For 3D plots and morphologies with many traced points, this can be very slow.

Parameters:
  • ax (Axes | None) – An axis into which to plot.

  • color (str) – The color for all branches.

  • dims (Tuple[int]) – Which dimensions to plot. 0=x, 1=y, 2=z coordinate. Must be a tuple of two of them.

  • type (str) – The type of plot. One of [“line”, “scatter”, “comp”, “morph”].

  • kwargs – Keyword arguments passed to the plotting function.

Return type:

Axes
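A sketch of how dims selects coordinates from the traced points, assuming dims indexes the columns of xyzr (0=x, 1=y, 2=z), which is consistent with the default dims=(0, 1) plotting the xy plane. project is a hypothetical helper.

```python
from typing import List, Sequence, Tuple


def project(
    xyzr: Sequence[Sequence[float]], dims: Tuple[int, ...] = (0, 1)
) -> List[Tuple[float, ...]]:
    """Keep only the coordinates listed in dims (0=x, 1=y, 2=z) of each point."""
    if any(d not in (0, 1, 2) for d in dims):
        raise ValueError("dims may only contain 0 (x), 1 (y) or 2 (z)")
    return [tuple(point[d] for d in dims) for point in xyzr]


# Two traced points as (x, y, z, r); the radius column is never a plot axis.
points = [(0.0, 1.0, 2.0, 0.5), (10.0, 11.0, 12.0, 0.4)]
print(project(points))               # xy plane: [(0.0, 1.0), (10.0, 11.0)]
print(project(points, dims=(0, 2)))  # xz plane: [(0.0, 2.0), (10.0, 12.0)]
```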

compute_xyz()

Return xyz coordinates of every branch, based on the branch length.

This function should not be called if the morphology was read from an .swc file. However, for morphologies that were constructed from scratch, this function must be called before .vis(). The computed xyz coordinates are only used for plotting.

move(x=0.0, y=0.0, z=0.0, update_nodes=False)

Move cells or networks by adding to their (x, y, z) coordinates.

This function is used only for visualization. It does not affect the simulation.

Parameters:
  • x (float) – The amount to move in the x direction in um.

  • y (float) – The amount to move in the y direction in um.

  • z (float) – The amount to move in the z direction in um.

  • update_nodes (bool) – Whether .nodes should be updated or not. Setting this to False substantially speeds up moving, especially for big networks, but .nodes and .show() will then not show the new xyz coordinates.

move_to(x=0.0, y=0.0, z=0.0, update_nodes=False)

Move cells or networks to a location (x, y, z).

If x, y, and z are floats, then the first compartment of the first branch of the first cell is moved to that float coordinate, and everything else is shifted by the difference between that compartment’s previous coordinate and the new float location.

If x, y, and z are arrays, then they must each have a length equal to the number of cells being moved. Then the first compartment of the first branch of each cell is moved to the specified location.

Parameters:
  • update_nodes (bool) – Whether .nodes should be updated or not. Setting this to False substantially speeds up moving, especially for big networks, but .nodes and .show() will then not show the new xyz coordinates.

  • x (float | ndarray)

  • y (float | ndarray)

  • z (float | ndarray)
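The float case of move_to amounts to a rigid translation. In this sketch the first point in the list stands in for the first compartment of the first branch of the first cell; move_to here is a hypothetical plain-Python function, not the jaxley method.

```python
from typing import List, Tuple

Point = Tuple[float, float, float]


def move_to(points: List[Point], target: Point) -> List[Point]:
    """Shift all points so that points[0] lands exactly on target."""
    # Difference between the target and the previous coordinate of points[0].
    dx, dy, dz = (t - p for t, p in zip(target, points[0]))
    return [(x + dx, y + dy, z + dz) for x, y, z in points]


cell = [(5.0, 5.0, 0.0), (15.0, 5.0, 0.0), (15.0, 25.0, 0.0)]
print(move_to(cell, (0.0, 0.0, 0.0)))
# [(0.0, 0.0, 0.0), (10.0, 0.0, 0.0), (10.0, 20.0, 0.0)]
```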

rotate(degrees, rotation_axis='xy', update_nodes=False)

Rotate jaxley modules clockwise.

This function is used only for visualization. It does not affect the simulation.

Parameters:
  • degrees (float) – How many degrees to rotate the module by.

  • rotation_axis (str) – One of xy, xz, or yz.

  • update_nodes (bool) – Whether .nodes should be updated or not.
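A sketch of a clockwise rotation in the xy plane. It rotates about the centroid of the points, which is an assumption rather than documented behavior; for xz or yz, the same formula would apply to the corresponding coordinate pair. rotate_clockwise is hypothetical.

```python
import math
from typing import List, Tuple

Point2 = Tuple[float, float]


def rotate_clockwise(points: List[Point2], degrees: float) -> List[Point2]:
    """Rotate 2D points clockwise by `degrees` about their centroid."""
    theta = math.radians(degrees)
    c, s = math.cos(theta), math.sin(theta)
    cx = sum(p[0] for p in points) / len(points)
    cy = sum(p[1] for p in points) / len(points)
    rotated = []
    for x, y in points:
        dx, dy = x - cx, y - cy
        # Clockwise rotation matrix [[cos, sin], [-sin, cos]] applied to (dx, dy).
        rotated.append((cx + c * dx + s * dy, cy - s * dx + c * dy))
    return rotated


# (1, 0) moves to (0, -1) and (-1, 0) to (0, 1), up to floating-point error.
print(rotate_clockwise([(-1.0, 0.0), (1.0, 0.0)], 90.0))
```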

copy_node_property_to_edges(properties_to_import, pre_or_post=['pre', 'post'])

Copy a property that is in .nodes over to .edges.

By default, .edges does not contain the properties (radius, length, cm, channel properties, …) of the pre- and post-synaptic compartments. This method copies a property of the pre- and/or post-synaptic compartment to the edges, where it is then accessible as module.edges.pre_property_name or module.edges.post_property_name.

Note that if you modify the node property after having run copy_node_property_to_edges, the value in .edges will not be updated automatically.

Note that if this method is called on a View (e.g. net.cell(0).copy_node_property_to_edges), it will return a View but will not modify the module itself.

Parameters:
  • properties_to_import (str | List[str]) – The name of the node properties that should be imported. To list all available properties, look at module.nodes.columns.

  • pre_or_post (str | List[str]) – Whether to import only the pre-synaptic property (‘pre’), only the post-synaptic property (‘post’), or both ([‘pre’, ‘post’]).

Returns:

A new module which has the property copied to the edges.

Return type:

Module
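A sketch of the copy, assuming each edge row stores the index of its pre- and post-synaptic compartment under the hypothetical keys pre_index and post_index. The new columns follow the pre_<property>/post_<property> naming described above; because values are copied, later changes to the nodes are not reflected in the returned edges.

```python
from typing import Dict, List, Sequence, Union

Row = Dict[str, float]


def copy_property_to_edges(
    nodes: Dict[int, Row],
    edges: List[Row],
    prop: str,
    pre_or_post: Union[str, Sequence[str]] = ("pre", "post"),
) -> List[Row]:
    """Return new edge rows with pre_<prop> and/or post_<prop> columns."""
    sides = [pre_or_post] if isinstance(pre_or_post, str) else list(pre_or_post)
    new_edges = []
    for edge in edges:
        row = dict(edge)  # copy, so the input edges stay untouched
        for side in sides:
            # Look up the compartment at this end of the edge (hypothetical key).
            row[f"{side}_{prop}"] = nodes[int(edge[f"{side}_index"])][prop]
        new_edges.append(row)
    return new_edges


nodes = {0: {"radius": 1.0}, 1: {"radius": 2.0}}
edges = [{"pre_index": 0, "post_index": 1}]
print(copy_property_to_edges(nodes, edges, "radius"))
```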

class Compartment

Compartment class.

This class defines a single compartment that can be simulated by itself or connected up into branches. It is the basic building block of a neuron model.

class Branch(compartments=None, ncomp=None)

Branch class.

This class defines a single branch that can be simulated by itself or connected to build a cell. A branch is a linear segment of several compartments and can be connected to zero, one, or more other branches at each end to build more intricate cell morphologies.

Parameters:

class Cell(branches=None, parents=None, xyzr=None)

Cell class.

This class defines a single cell that can be simulated by itself or connected with synapses to build a network. A cell is made up of several branches and supports intricate cell morphologies.

Parameters:

class Network(cells)

Network class.

This class defines a network of cells that can be connected with synapses.

Parameters:

cells (List[Cell])

arrange_in_layers(layers, within_layer_offset=500.0, between_layer_offset=1500.0, vertical_layers=False)

Arrange the cells in the network to form layers.

Moves the cells in the network to arrange them into layers.

Parameters:
  • layers (List[int]) – List of integers specifying the number of cells in each layer.

  • within_layer_offset (float) – Offset between cells within the same layer.

  • between_layer_offset (float) – Offset between layers.

  • vertical_layers (bool) – If True, layers are arranged vertically.
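One possible layout rule, sketched under stated assumptions: cells within a layer are spaced by within_layer_offset and centered on 0, consecutive layers are spaced by between_layer_offset, and vertical_layers swaps the two axes. layer_positions is hypothetical; jaxley's exact placement may differ.

```python
from typing import List, Tuple


def layer_positions(
    layers: List[int],
    within_layer_offset: float = 500.0,
    between_layer_offset: float = 1500.0,
    vertical_layers: bool = False,
) -> List[Tuple[float, float]]:
    """Return one (x, y) anchor per cell, layer by layer."""
    positions = []
    for layer_idx, num_cells in enumerate(layers):
        centre = (num_cells - 1) / 2  # center each layer on 0
        for i in range(num_cells):
            along = (i - centre) * within_layer_offset
            across = layer_idx * between_layer_offset
            positions.append((across, along) if vertical_layers else (along, across))
    return positions


print(layer_positions([3, 1]))
# [(-500.0, 0.0), (0.0, 0.0), (500.0, 0.0), (0.0, 1500.0)]
```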

vis(detail='full', ax=None, color='k', synapse_color='b', dims=(0, 1), cell_plot_kwargs={}, synapse_plot_kwargs={}, synapse_scatter_kwargs={}, **kwargs)

Visualize the module.

Parameters:
  • detail (str) – Either point or full. point visualizes every neuron in the network as a dot. full plots the full morphology of every neuron and requires that compute_xyz() has been run.

  • color (str) – The color in which cells are plotted.

  • synapse_color (str) – The color in which synapses are plotted.

  • dims (Tuple[int]) – Which dimensions to plot. 0=x, 1=y, 2=z coordinate. Must be a tuple of two of them.

  • cell_plot_kwargs (Dict) – Keyword arguments passed to the plotting function for cell morphologies. Only takes effect for detail='full'.

  • synapse_plot_kwargs (Dict) – Keyword arguments passed to the plotting function for synapses.

  • synapse_scatter_kwargs (Dict) – Keyword arguments passed to the scatter function for synapse terminals.

  • ax (Axes | None)

Return type:

Axes