jaxley.Network#

class Network(cells, vectorize_cells=None)[source]#

Bases: Module

A network made up of multiple cells, connected by synapses.

This class defines a network of cells. These cells can later on be connected with synapses via jx.connect.

Parameters:
network_params: Dict = {}#
network_states: Dict = {}#
arrange_in_layers(layers, within_layer_offset=500.0, between_layer_offset=1500.0, vertical_layers=False)[source]#

Arrange the cells in the network to form layers.

Moves the cells in the network to arrange them into layers.

Parameters:
  • layers (List[int]) – List of integers specifying the number of cells in each layer.

  • within_layer_offset (float) – Offset between cells within the same layer.

  • between_layer_offset (float) – Offset between layers.

  • vertical_layers (bool) – If True, layers are arranged vertically.

vis(detail='full', ax=None, color='k', synapse_color='b', dims=(0, 1), cell_plot_kwargs={}, synapse_plot_kwargs={}, synapse_scatter_kwargs={}, **kwargs)[source]#

Visualize the module.

Parameters:
  • detail (str) – Either of [point, full]. point visualizes every neuron in the network as a dot. full plots the full morphology of every neuron. It requires that compute_xyz() has been run.

  • color (str) – The color in which cells are plotted.

  • synapse_color (str) – The color in which synapses are plotted.

  • dims (Tuple[int]) – Which dimensions to plot. 0=x, 1=y, 2=z coordinate. Must be a tuple of two of them.

  • cell_plot_kwargs (Dict) – Keyword arguments passed to the plotting function for cell morphologies. Only takes effect for detail='full'.

  • synapse_plot_kwargs (Dict) – Keyword arguments passed to the plotting function for synapses.

  • synapse_scatter_kwargs (Dict) – Keyword arguments passed to the scatter function for synapse terminals.

  • ax (Axes | None)

Return type:

Axes

add_to_group(group_name)#

Add a view of the module to a group.

Groups can then be indexed. For example:

net.cell(0).add_to_group("excitatory")
net.excitatory.set("radius", 0.1)
Parameters:

group_name (str) – The name of the group.

branch(idx)#

Return a View of the module at the selected branch(es).

Parameters:

idx (Any) – index of the branch to view.

Returns:

View of the module at the specified branch index.

Return type:

View

property branches#

Iterate over all branches in the module.

Returns a generator that yields a View of each branch.

cell(idx)#

Return a View of the module at the selected cell(s).

Parameters:

idx (Any) – index of the cell to view.

Returns:

View of the module at the specified cell index.

Return type:

View

property cells#

Iterate over all cells in the module.

Returns a generator that yields a View of each cell.

clamp(state_name, state_array, verbose=True)#

Clamp a state to a given value across specified compartments.

Parameters:

This function sets external states for the compartments.

comp(idx)#

Return a View of the module at the selected compartment(s).

Parameters:

idx (Any) – index of the comp to view.

Returns:

View of the module at the specified compartment index.

Return type:

View

property comps#

Iterate over all compartments in the module. Can be called on any module, i.e. net.comps, cell.comps or branch.comps. __iter__ does not allow for this.

Returns a generator that yields a View of each compartment.

compute_compartment_centers()#

Add compartment centers to the nodes dataframe.

compute_xyz()#

Return xyz coordinates of every branch, based on the branch length.

This function should not be called if the morphology was read from an .swc file. However, for morphologies that were constructed from scratch, this function must be called before .vis(). The computed xyz coordinates are only used for plotting.

copy(reset_index=False, as_module=False)#

Extract part of a module and return a copy of its View or a new module.

This can be used to call jx.integrate on part of a Module.

Parameters:
  • reset_index (bool) – if True, the indices of the new module are reset to start from 0.

  • as_module (bool) – if True, a new module is returned instead of a View.

Returns:

A part of the module or a copied view of it.

Return type:

Module | View

copy_node_property_to_edges(properties_to_import, pre_or_post=['pre', 'post'])#

Copy a property that is in node over to edges.

By default, .edges does not contain the properties (radius, length, cm, channel properties, …) of the pre- and post-synaptic compartments. This method copies a property of the pre- and/or post-synaptic compartment to the edges. It is then accessible as module.edges.pre_property_name or module.edges.post_property_name.

Note that, if you modify the node property _after_ having run copy_node_property_to_edges, it will not automatically update the value in .edges.

Note that, if this method is called on a View (e.g. net.cell(0).copy_node_property_to_edges), then it will return a View, but it will _not_ modify the module itself.

Parameters:
  • properties_to_import (str | List[str]) – The name of the node properties that should be imported. To list all available properties, look at module.nodes.columns.

  • pre_or_post (str | List[str]) – Whether to import only the pre-synaptic property ('pre'), only the post-synaptic property ('post'), or both (['pre', 'post']).

Returns:

A new module which has the property copied to the edges.

Return type:

Module

data_clamp(state_name, state_array, data_clamps=None, verbose=False)#

Insert a clamp into the module within jit (or grad).

Parameters:
  • state_name (str) – Name of the state variable to set.

  • state_array (Array | ndarray | bool | number | bool | int | float | complex) – Time series of the state variable in the default Jaxley unit. State array should be of shape (num_clamps, simulation_time) or (simulation_time, ) for a single clamp.

  • verbose (bool) – Whether or not to print the number of inserted clamps. False by default because this method is meant to be jitted.

  • data_clamps (tuple[Array | ndarray | bool | number | bool | int | float | complex, DataFrame] | None)

data_set(key, val, param_state)#

Set parameter of module (or its view) to a new value within jit.

Parameters:
  • key (str) – The name of the parameter to set.

  • val (float | Array | ndarray | bool | number | bool | int | complex) – The value to set the parameter to. If it is ArrayLike then it must be of shape (num_compartments,).

  • param_state (list[dict] | None) – State of the set parameters, internally used such that this function does not modify global state.

data_stimulate(current, data_stimuli=None, verbose=False)#

Insert a stimulus into the module within jit (or grad).

Parameters:
Return type:

tuple[Array, DataFrame]

delete(channel)#

Remove a channel or pump from the module.

Parameters:

channel (Channel | Pump) – The channel to remove.

Example usage#

The example below inserts two channels and then deletes the sodium channel:

import jaxley as jx
from jaxley.channels import Na, K

cell = jx.Cell()
cell.insert(Na())
cell.insert(K())

# Remove the sodium channel again; the potassium channel stays inserted.
cell.delete(Na())

The example below inserts two channels and then deletes both of them:

import jaxley as jx
from jaxley.channels import Na, K

cell = jx.Cell()
cell.insert(Na())
cell.insert(K())

# Loop over all channels and delete each one. Note: The `list(...)` is
# important here because `.delete()` modifies the `.channels`.
for channel in list(cell.channels):
    cell.delete(channel)
delete_clamps(state_name=None)#

Removes all clamps of the given state from the module.

Parameters:

state_name (str | None)

delete_diffusion(*args, **kwargs)#
delete_recordings()#

Removes all recordings from the module.

delete_stimuli()#

Removes all stimuli from the module.

delete_trainables()#

Removes all trainable parameters from the module.

diffuse(*args, **kwargs)#
distance(**kwargs)#
edge(idx)#

Return a View of the module at the selected synapse edge(s).

Parameters:

idx (Any) – index of the edge to view.

Returns:

View of the module at the specified edge index.

Return type:

View

get_all_parameters(*args, **kwargs)#
get_all_states(*args, **kwargs)#
get_parameters()#

Get all trainable parameters.

The returned parameters should be passed to jx.integrate(..., params=params).

Returns:

A list of all trainable parameters in the form of

[{"gNa": jnp.array([0.1, 0.2, 0.3])}, …].

Return type:

list[dict[str, Array]]

init_states(*args, **kwargs)#
initialize()#

Initialize the module.

This function does several things: 1) It computes local indices in the .nodes dataframe (from global indices). 2) It builds the compartment graph (._comp_edges and ._branchpoints). 3) It initializes the View. 4) It initializes all solvers required for solving the differential equation.

This function should be run whenever the graph structure (i.e., the morphology or the compartmentalization) of the module has been changed. Built-in functions such as morph_attach(), morph_delete(), or set_ncomp() run this function automatically, so there is no need for the user to run it manually.

property initialized: bool#

Whether the Module is ready to be solved or not.

insert(channel)#

Insert a channel or pump into the module.

Parameters:

channel (Channel | Pump) – The channel to insert.

loc(at)#

Return a View of the module at the selected branch location(s).

Parameters:

at (Any) – location along the branch.

Returns:

View of the module at the specified branch location.

Return type:

View

make_trainable(key, init_val=None, verbose=True)#

Make a parameter trainable.

If a parameter is made trainable, it will be returned by get_parameters() and should then be passed to jx.integrate(…, params=params).

Parameters:
  • key (str) – Name of the parameter to make trainable.

  • init_val (float | list | None) – Initial value of the parameter. If float, the same value is used for every created parameter. If list, the length of the list has to match the number of created parameters. If None, the current parameter value is used; if parameter sharing is performed, the current parameter value is averaged over all shared parameters.

  • verbose (bool) – Whether to print the number of parameters that are added and the total number of parameters.

move(x=0.0, y=0.0, z=0.0, update_nodes=False)#

Move cells or networks by adding to their (x, y, z) coordinates.

This function is used only for visualization. It does not affect the simulation.

Parameters:
  • x (float) – The amount to move in the x direction in um.

  • y (float) – The amount to move in the y direction in um.

  • z (float) – The amount to move in the z direction in um.

  • update_nodes (bool) – Whether .nodes should be updated or not. Setting this to False considerably speeds up moving, especially for big networks, but .nodes or .show will not show the new xyz coordinates.

move_to(x=0.0, y=0.0, z=0.0, update_nodes=False)#

Move cells or networks to a location (x, y, z).

If x, y, and z are floats, then the first compartment of the first branch of the first cell is moved to that float coordinate, and everything else is shifted by the difference between that compartment’s previous coordinate and the new float location.

If x, y, and z are arrays, then they must each have a length equal to the number of cells being moved. Then the first compartment of the first branch of each cell is moved to the specified location.

Parameters:
  • update_nodes (bool) – Whether .nodes should be updated or not. Setting this to False considerably speeds up moving, especially for big networks, but .nodes or .show will not show the new xyz coordinates.

  • x (float | ndarray)

  • y (float | ndarray)

  • z (float | ndarray)
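The shift described for float arguments can be illustrated with plain NumPy. This is only a sketch of the arithmetic, not jaxley's implementation; the helper move_to_sketch and the example coordinates are hypothetical.

```python
import numpy as np

def move_to_sketch(xyz, target):
    """Shift all points so that the first point lands on `target` (hypothetical helper)."""
    xyz = np.asarray(xyz, dtype=float)
    # Difference between the new location and the first point's previous coordinate.
    offset = np.asarray(target, dtype=float) - xyz[0]
    # Every point is shifted by the same offset, so relative positions are preserved.
    return xyz + offset

# Made-up compartment coordinates; the first row plays the role of the first
# compartment of the first branch of the first cell.
coords = np.array([[10.0, 0.0, 0.0], [20.0, 5.0, 0.0], [30.0, 5.0, 2.0]])
moved = move_to_sketch(coords, target=(0.0, 0.0, 0.0))
print(moved)
```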

record(state='v', verbose=True)#
Parameters:

state (str)

rotate(degrees, rotation_axis='xy', update_nodes=False)#

Rotate jaxley modules clockwise.

This function is used only for visualization. It does not affect the simulation.

Parameters:
  • degrees (float) – How many degrees to rotate the module by.

  • rotation_axis (str) – Either of {xy | xz | yz}.

  • update_nodes (bool)
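To make the geometry concrete, here is a NumPy sketch of a clockwise rotation in one of the three planes. It is not jaxley's code: the helper rotate_clockwise_sketch is hypothetical, and "clockwise" is assumed to mean clockwise when the first axis of the plane points right and the second points up.

```python
import numpy as np

def rotate_clockwise_sketch(xyz, degrees, rotation_axis="xy"):
    """Rotate points clockwise within the chosen plane (hypothetical helper)."""
    # Map the plane name to the two coordinate columns that are rotated.
    axes = {"xy": [0, 1], "xz": [0, 2], "yz": [1, 2]}[rotation_axis]
    theta = np.deg2rad(degrees)
    # Clockwise rotation matrix under the stated axis convention.
    rot = np.array([
        [np.cos(theta), np.sin(theta)],
        [-np.sin(theta), np.cos(theta)],
    ])
    out = np.array(xyz, dtype=float)
    out[:, axes] = out[:, axes] @ rot.T
    return out

points = np.array([[1.0, 0.0, 3.0]])
rotated = rotate_clockwise_sketch(points, 90.0, "xy")
print(np.round(rotated, 6))
```

The coordinate that is not part of the chosen plane (z in this case) is left unchanged.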

scope(scope)#

Return a View of the module with the specified scope.

For example, cell.scope("global").branch(2).scope("local").comp(1) will return the 1st compartment of branch 2.

Parameters:

scope (str) – either “global” or “local”.

Returns:

View with the specified scope.

Return type:

View

select(nodes=None, edges=None, sorted=False)#

Return a View of the module filtered by specific node or edge indices.

The selection is made based on the index of the self.nodes or self.edges, i.e., not on a local compartment index or a local row number (loc, not iloc).

Parameters:
  • nodes (ndarray | None) – indices of nodes to view. If None, all nodes are viewed.

  • edges (ndarray | None) – indices of edges to view. If None, all edges are viewed.

  • sorted (bool) – if True, nodes and edges are sorted.

Returns:

View for subset of selected nodes and/or edges.

Return type:

View
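The "loc, not iloc" distinction can be illustrated with a stand-in pandas DataFrame. The frame below is made up for illustration and is not a real .nodes table; it only shows how selecting by index label differs from selecting by row position.

```python
import pandas as pd

# Stand-in for a nodes table whose index does not start at 0,
# as can happen when looking at part of a larger module.
nodes = pd.DataFrame({"radius": [1.0, 1.0, 0.5, 0.5]}, index=[4, 5, 6, 7])

# Label-based selection (loc): picks the rows whose index is 5 and 7.
by_label = nodes.loc[[5, 7]]

# Position-based selection (iloc): picks the first two rows, whatever their index.
by_position = nodes.iloc[[0, 1]]

print(by_label.index.tolist())
print(by_position.index.tolist())
```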

set(key, val)#

Set parameter of module (or its view) to a new value.

Note that this function cannot be called within jax.jit or jax.grad. Instead, it should be used to set the parameters of the module before the simulation. Use .data_set() to set parameters during jax.jit or jax.grad.

Parameters:
  • key (str) – The name of the parameter to set.

  • val (float | Array | ndarray | bool | number | bool | int | complex) – The value to set the parameter to. If it is ArrayLike then it must be of shape (num_compartments,).

set_ncomp(ncomp, min_radius=None, initialize=True)#

Set the number of compartments with which the branch is discretized.

Parameters:
  • ncomp (int) – The number of compartments that the branch should be discretized into.

  • min_radius (float | None) – Only used if the morphology was read from an SWC file. If passed the radius is capped to be at least this value.

  • initialize (bool) – If False, the initialization stage is skipped and the user has to run cell.initialize() manually afterwards. This is useful when set_ncomp is run in a loop (e.g. for the d_lambda rule), where initializing only once after the entire loop considerably reduces computation time.

Raises:
  • When there are stimuli in any compartment in the module.

  • When there are recordings in any compartment in the module.

  • When the channels of the compartments are not the same within the branch that is modified.

  • When the lengths of the compartments are not the same within the branch that is modified.

  • When the branch that is modified has compartments belonging to different groups.

  • Unless the morphology was read from an SWC file, when the radii of the compartments are not the same within the branch that is modified.

set_scope(scope)#

Toggle between “global” or “local” scope.

Determines if global or local indices are used for viewing the module.

Parameters:

scope (str) – either “global” or “local”.

property shape: Tuple[int]#

Returns the number of submodules contained in a module.

network.shape = (num_cells, num_branches, num_compartments)
cell.shape = (num_branches, num_compartments)
branch.shape = (num_compartments,)
show(param_names=None, *, indices=True, params=True, states=True, channel_names=None)#

Print detailed information about the Module or a view of it.

Parameters:
  • param_names (str | List[str] | None) – The names of the parameters to show. If None, all parameters are shown.

  • indices (bool) – Whether to show the indices of the compartments.

  • params (bool) – Whether to show the parameters of the compartments.

  • states (bool) – Whether to show the states of the compartments.

  • channel_names (List[str] | None) – The names of the channels to show. If None, all channels are shown.

Returns:

A pd.DataFrame with the requested information.

Return type:

DataFrame

step(*args, **kwargs)#
stimulate(current=None, verbose=True)#

Insert a stimulus into the compartment.

current must be a 1d array or have batch dimension of size (num_compartments, ) or (1, ). If 1d, the same stimulus is added to all compartments.

This function cannot be run during jax.jit and jax.grad. Because of this, it should only be used for static stimuli (i.e., stimuli that do not depend on the data and that should not be learned). For stimuli that depend on data (or that should be learned), please use data_stimulate().

Parameters:
to_jax(*args, **kwargs)#
property view#

Return view of the module.

write_trainables(trainable_params)#

Write the trainables into .nodes and .edges.

This makes it possible, e.g., to visualize trained networks with .vis().

Parameters:

trainable_params (list[dict[str, Array]]) – The trainable parameters returned by get_parameters().