Modules
- class Module
Module base class.
This base class defines the scaffold for all jaxley modules, i.e., everything that can be passed to jx.integrate: compartments, branches, cells, and networks.
Modules can be traversed and modified using the at, cell, branch, comp, edge, and loc methods. The scope method can be used to toggle between global and local indices. Traversing a Module returns a View of it whose attributes only consider the part of the Module that is in view.
For developers: This has consequences for how to operate on a Module and for which changes take effect where. The following guidelines should be followed (copied from View):
- We consider a Module to have everything in view.
- Views can display and keep track of how a module is traversed, but they do not support making changes or setting variables. Such changes still have to be made in the base Module, i.e., self.base. To ensure that these changes only affect whatever is currently in view, self._nodes_in_view, self._edges_in_view, and similar attributes have to be used. For example, nodes currently in view can be operated on with self.base.nodes.loc[self._nodes_in_view].
- Every attribute of Module that changes based on what is in view, e.g., xyzr, needs to be modified when a View is instantiated. For instance, xyzr of cell.branch(0) should be [self.base.xyzr[0]], which can be achieved via [self.base.xyzr[b] for b in self._branches_in_view].
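The rule "read from the view, write through self.base at the viewed indices" can be illustrated with a minimal plain-Python toy. ToyModule, ToyView, comp, and set_radius here are hypothetical stand-ins, not jaxley classes; they only mirror the pattern described in the guidelines.

```python
from typing import List


class ToyModule:
    """Hypothetical stand-in for a Module: it owns the data and has everything in view."""

    def __init__(self, radius: List[float]) -> None:
        self.radius = list(radius)
        self.base = self  # a module is its own base
        self._nodes_in_view = list(range(len(radius)))

    def comp(self, idx: List[int]) -> "ToyView":
        # Traversal only narrows the indices in view; no data is copied.
        return ToyView(self.base, [self._nodes_in_view[i] for i in idx])


class ToyView:
    """Hypothetical stand-in for a View: it only tracks which nodes are in view."""

    def __init__(self, base: ToyModule, nodes_in_view: List[int]) -> None:
        self.base = base
        self._nodes_in_view = nodes_in_view

    @property
    def radius(self) -> List[float]:
        # Reading: derive the viewed attribute from the base and the viewed indices.
        return [self.base.radius[i] for i in self._nodes_in_view]

    def set_radius(self, val: float) -> None:
        # Writing: change self.base, restricted to the nodes in view.
        for i in self._nodes_in_view:
            self.base.radius[i] = val


module = ToyModule([1.0, 1.0, 1.0, 1.0])
view = module.comp([1, 3])
view.set_radius(0.5)
print(module.radius)  # [1.0, 0.5, 1.0, 0.5]
print(view.radius)  # [0.5, 0.5]
```

The view never stores its own copy of radius, so a change made through it is immediately visible in the base module and in any other view of the same nodes.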
For developers: If you want to add a new method to Module, here is an example of how to make methods of Module compatible with View:
```python
import numpy as np

# Use data in view to return something.
def count_small_branches(self):
    # No need to use self.base.attr + viewed indices,
    # since no change is made to the attr in question (nodes).
    # Sum the compartment lengths of each branch.
    branch_lens = self.nodes.groupby("global_branch_index")["length"].sum()
    return np.sum(branch_lens < 10)

# Change data in view.
def change_attr_in_view(self):
    # Changes to attrs have to be made via self.base.attr + viewed indices.
    # attr1, attr2, attr3, func1, and func2 are placeholders.
    a = func1(self.base.attr1[self._cells_in_view])
    b = func2(self.base.attr2[self._edges_in_view])
    self.base.attr3[self._branches_in_view] = a + b
```
- select(nodes=None, edges=None, sorted=False)
Return a View of the module filtered by specific node and/or edge indices.
- Parameters:
- Returns:
View for subset of selected nodes and/or edges.
- Return type:
View
- set_scope(scope)
Toggle between “global” and “local” scope.
Determines if global or local indices are used for viewing the module.
- Parameters:
scope (str) – either “global” or “local”.
- scope(scope)
Return a View of the module with the specified scope.
For example, cell.scope(“global”).branch(2).scope(“local”).comp(1) will return compartment 1 (a local index, i.e., counted within the branch) of branch 2 (a global index).
- Parameters:
scope (str) – either “global” or “local”.
- Returns:
View with the specified scope.
- Return type:
View
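Assuming that global compartment indices number compartments consecutively, branch by branch, the relation between a local compartment index (counted within one branch) and a global one can be sketched in plain Python. The helper local_to_global_comp is hypothetical and not part of jaxley.

```python
from itertools import accumulate
from typing import List


def local_to_global_comp(ncomp_per_branch: List[int], branch: int, local_comp: int) -> int:
    """Map (global branch index, local compartment index) to a global compartment index.

    Assumes compartments are numbered consecutively, branch by branch.
    """
    if not 0 <= local_comp < ncomp_per_branch[branch]:
        raise IndexError("local compartment index out of range for this branch")
    # offsets[b] = number of compartments in all branches before branch b
    offsets = [0] + list(accumulate(ncomp_per_branch))[:-1]
    return offsets[branch] + local_comp


# Three branches with 4, 2, and 3 compartments.
print(local_to_global_comp([4, 2, 3], branch=2, local_comp=1))  # 7
```

Branches 0 and 1 together hold 6 compartments, so local compartment 1 of branch 2 has global index 6 + 1 = 7.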
- cell(idx)
Return a View of the module at the selected cell(s).
- Parameters:
idx (Any) – index of the cell to view.
- Returns:
View of the module at the specified cell index.
- Return type:
View
- branch(idx)
Return a View of the module at the selected branch(es).
- Parameters:
idx (Any) – index of the branch to view.
- Returns:
View of the module at the specified branch index.
- Return type:
View
- comp(idx)
Return a View of the module at the selected compartment(s).
- Parameters:
idx (Any) – index of the compartment to view.
- Returns:
View of the module at the specified compartment index.
- Return type:
View
- edge(idx)
Return a View of the module at the selected synapse edge(s).
- Parameters:
idx (Any) – index of the edge to view.
- Returns:
View of the module at the specified edge index.
- Return type:
View
- loc(at)
Return a View of the module at the selected branch location(s).
- Parameters:
at (Any) – location along the branch.
- Returns:
View of the module at the specified branch location.
- Return type:
View
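Assuming the location passed to loc is a fraction of the branch length in [0, 1] and the branch is split into equally long compartments, one plausible mapping from a location to a compartment index is sketched below. loc_to_comp is hypothetical; jaxley's own handling, in particular at the branch ends, may differ.

```python
def loc_to_comp(loc: float, ncomp: int) -> int:
    """Map a relative location along a branch to a compartment index.

    Assumes loc is normalized to [0, 1] and the branch is split into ncomp
    equally long compartments; loc == 1.0 is mapped to the last compartment.
    """
    if not 0.0 <= loc <= 1.0:
        raise ValueError("loc must lie in [0, 1]")
    return min(int(loc * ncomp), ncomp - 1)


print([loc_to_comp(x, 4) for x in (0.0, 0.3, 0.5, 1.0)])  # [0, 1, 2, 3]
```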
- property cells
Iterate over all cells in the module.
Returns a generator that yields a View of each cell.
- property branches
Iterate over all branches in the module.
Returns a generator that yields a View of each branch.
- property comps
Iterate over all compartments in the module. Can be called on any module, e.g., net.comps, cell.comps, or branch.comps. __iter__ does not allow for this.
Returns a generator that yields a View of each compartment.
- property shape: Tuple[int]
Returns the number of submodules contained in a module:
network.shape = (num_cells, num_branches, num_compartments)
cell.shape = (num_branches, num_compartments)
branch.shape = (num_compartments,)
- copy(reset_index=False, as_module=False)
Extract part of a module and return a copy of its View or a new module.
This can be used to call jx.integrate on part of a Module.
- property view
Return a View of the module.
- show(param_names=None, *, indices=True, params=True, states=True, channel_names=None)
Print detailed information about the Module or a view of it.
- Parameters:
param_names (str | List[str] | None) – The names of the parameters to show. If None, all parameters are shown.
indices (bool) – Whether to show the indices of the compartments.
params (bool) – Whether to show the parameters of the compartments.
states (bool) – Whether to show the states of the compartments.
channel_names (List[str] | None) – The names of the channels to show. If None, all channels are shown.
- Returns:
A pd.DataFrame with the requested information.
- Return type:
DataFrame
- set(key, val)
Set parameter of module (or its view) to a new value.
Note that this function cannot be called within jax.jit or jax.grad. Instead, it should be used to set the parameters of the module before the simulation. Use .data_set() to set parameters during jax.jit or jax.grad.
- data_set(key, val, param_state)
Set parameter of module (or its view) to a new value within jit.
- Parameters:
key (str) – The name of the parameter to set.
val (float | Array) – The value to set the parameter to. If it is a jnp.ndarray, then it must be of shape (num_compartments,).
param_state (List[Dict] | None) – State of the parameters set so far; used internally so that this function does not modify global state.
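The param_state argument follows a functional pattern: instead of mutating the module, each call returns an updated record of pending parameter changes, which is what makes this usable under jax.jit. A plain-Python sketch of that idea; toy_data_set, apply_updates, and the dictionary keys are hypothetical and do not describe jaxley's internal format.

```python
from typing import Dict, List, Optional, Sequence


def toy_data_set(
    key: str,
    val: Sequence[float],
    indices: Sequence[int],
    param_state: Optional[List[Dict]] = None,
) -> List[Dict]:
    """Record a parameter update without mutating shared state.

    Hypothetical sketch: returns a new list holding all previous updates plus
    this one; the input list is left untouched.
    """
    update = {"key": key, "val": list(val), "indices": list(indices)}
    previous = [] if param_state is None else param_state
    return previous + [update]


def apply_updates(params: Dict[str, List[float]], param_state: List[Dict]) -> Dict[str, List[float]]:
    """Apply recorded updates, in order, to a copy of the parameter table."""
    out = {k: list(v) for k, v in params.items()}
    for update in param_state:
        for i, v in zip(update["indices"], update["val"]):
            out[update["key"]][i] = v
    return out


params = {"radius": [1.0, 1.0, 1.0]}
state = toy_data_set("radius", [0.5, 0.5], [0, 2])
state = toy_data_set("radius", [0.2], [1], state)
print(apply_updates(params, state))  # {'radius': [0.5, 0.2, 0.5]}
print(params)  # {'radius': [1.0, 1.0, 1.0]}
```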
- set_ncomp(ncomp, min_radius=None)
Set the number of compartments with which the branch is discretized.
- Parameters:
- Raises:
- When there are stimuli in any compartment in the module.
- When there are recordings in any compartment in the module.
- When the channels of the compartments are not the same within the branch that is modified.
- When the lengths of the compartments are not the same within the branch that is modified.
- Unless the morphology was read from an SWC file, when the radii of the compartments are not the same within the branch that is modified.
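The conditions under Raises can be read as a validation step that runs before a branch is re-discretized. A hypothetical sketch of such a check on plain dictionaries; the function name, the dict layout, and the choice of ValueError are assumptions, not jaxley's implementation.

```python
from typing import Dict, List


def check_branch_can_be_rediscretized(
    comps: List[Dict], has_stimuli: bool, has_recordings: bool, from_swc: bool
) -> None:
    """Raise ValueError if a branch violates the documented conditions.

    Hypothetical sketch: each compartment is a dict with "channels" (a set of
    channel names), "length", and "radius".
    """
    if has_stimuli:
        raise ValueError("Remove stimuli before changing ncomp.")
    if has_recordings:
        raise ValueError("Remove recordings before changing ncomp.")
    if len({frozenset(c["channels"]) for c in comps}) > 1:
        raise ValueError("Channels differ between compartments of the branch.")
    if len({c["length"] for c in comps}) > 1:
        raise ValueError("Compartment lengths differ within the branch.")
    # Radii may differ only if the morphology was read from an SWC file.
    if not from_swc and len({c["radius"] for c in comps}) > 1:
        raise ValueError("Compartment radii differ within the branch.")


comps = [
    {"channels": {"HH"}, "length": 10.0, "radius": 1.0},
    {"channels": {"HH"}, "length": 10.0, "radius": 2.0},
]
# Passes: differing radii are tolerated for SWC morphologies.
check_branch_can_be_rediscretized(comps, has_stimuli=False, has_recordings=False, from_swc=True)
```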
- make_trainable(key, init_val=None, verbose=True)
Make a parameter trainable.
If a parameter is made trainable, it will be returned by get_parameters() and should then be passed to jx.integrate(…, params=params).
- Parameters:
key (str) – Name of the parameter to make trainable.
init_val (float | list | None) – Initial value of the parameter. If float, the same value is used for every created parameter. If list, the length of the list has to match the number of created parameters. If None, the current parameter value is used, and if parameter sharing is performed, the current parameter value is averaged over all shared parameters.
verbose (bool) – Whether to print the number of parameters that are added and the total number of parameters.
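The three cases for init_val can be made concrete with a small helper that returns one initial value per created parameter. initial_trainable_values and its groups argument (which compartments share a parameter) are hypothetical; the sketch only mirrors the documented rules.

```python
from statistics import mean
from typing import Dict, List, Optional, Sequence, Union


def initial_trainable_values(
    current: Sequence[float],
    groups: Sequence[int],
    init_val: Optional[Union[float, List[float]]] = None,
) -> List[float]:
    """Compute one initial value per created (possibly shared) parameter.

    Hypothetical sketch: current[i] is the present value in compartment i and
    groups[i] is the id of the shared parameter that compartment i belongs to.
    Parameters are returned in order of first appearance of their group.
    """
    members: Dict[int, List[float]] = {}
    for value, group in zip(current, groups):
        members.setdefault(group, []).append(value)
    num_params = len(members)

    if init_val is None:
        # Use the current values, averaged over all shared compartments.
        return [mean(values) for values in members.values()]
    if isinstance(init_val, (int, float)):
        # The same value for every created parameter.
        return [float(init_val)] * num_params
    if len(init_val) != num_params:
        raise ValueError(f"Expected {num_params} initial values, got {len(init_val)}.")
    return [float(v) for v in init_val]


# Compartments 0 and 1 share one parameter; compartment 2 has its own.
print(initial_trainable_values([1.0, 3.0, 5.0], groups=[0, 0, 1]))  # [2.0, 5.0]
```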
- write_trainables(trainable_params)
Write the trainables into .nodes and .edges.
This allows one to, e.g., visualize trained networks with .vis().
- distance(endpoint)
Return the direct distance between two compartments. This does not compute the pathwise distance (which is currently not implemented).
- Parameters:
endpoint (View) – The compartment to which to compute the distance.
- Return type:
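Assuming "direct distance" means the straight-line (Euclidean) distance between the xyz coordinates of the two compartments, as opposed to the distance along the dendritic path, it reduces to math.dist. direct_distance is a hypothetical helper, and representing each compartment by a single point is an assumption.

```python
import math
from typing import Tuple

Point = Tuple[float, float, float]


def direct_distance(start: Point, endpoint: Point) -> float:
    """Straight-line (Euclidean) distance between two compartments, in um.

    Assumes each compartment is represented by a single (x, y, z) coordinate.
    """
    return math.dist(start, endpoint)


print(direct_distance((0.0, 0.0, 0.0), (3.0, 4.0, 0.0)))  # 5.0
```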
- add_to_group(group_name)
Add a view of the module to a group.
Groups can then be indexed. For example:
```python
net.cell(0).add_to_group("excitatory")
net.excitatory.set("radius", 0.1)
```
- Parameters:
group_name (str) – The name of the group.
- get_parameters()
Get all trainable parameters.
The returned parameters should be passed to jx.integrate(…, params=params).
- stimulate(current=None, verbose=True)
Insert a stimulus into the compartment.
current must be a 1d array (a single time series) or a 2d array whose batch (first) dimension has size num_compartments or 1. If 1d, the same stimulus is added to all compartments.
This function cannot be run during jax.jit and jax.grad. Because of this, it should only be used for static stimuli (i.e., stimuli that do not depend on the data and that should not be learned). For stimuli that depend on data (or that should be learned), please use data_stimulate().
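The shape rule for current can be sketched with nested lists: a 1d current, or a 2d current with batch size 1, is repeated for every compartment, while a batch of size num_compartments is used as is. broadcast_current is a hypothetical helper, not part of jaxley; it ignores units and time-step handling and assumes the batch-size-1 case is broadcast like the 1d case.

```python
from typing import List, Sequence, Union

Current = Union[Sequence[float], Sequence[Sequence[float]]]


def broadcast_current(current: Current, num_compartments: int) -> List[List[float]]:
    """Return one current time series per compartment.

    Hypothetical sketch: 2d inputs are nested lists or tuples, where the
    first dimension is the batch and the second is time.
    """
    if len(current) == 0:
        raise ValueError("current must not be empty")
    if not isinstance(current[0], (list, tuple)):
        # 1d: the same stimulus for every compartment.
        return [list(current) for _ in range(num_compartments)]
    if len(current) == 1:
        # Batch size 1: also repeated for every compartment.
        return [list(current[0]) for _ in range(num_compartments)]
    if len(current) == num_compartments:
        # One time series per compartment.
        return [list(row) for row in current]
    raise ValueError(f"batch size must be 1 or {num_compartments}, got {len(current)}")


print(broadcast_current([0.0, 0.1], num_compartments=2))  # [[0.0, 0.1], [0.0, 0.1]]
```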
- clamp(state_name, state_array, verbose=True)
Clamp a state to a given value across specified compartments.
This function sets external states for the compartments.
- Parameters:
- data_stimulate(current, data_stimuli=None, verbose=False)
Insert a stimulus into the module within jit (or grad).
- data_clamp(state_name, state_array, data_clamps=None, verbose=False)
Insert a clamp into the module within jit (or grad).
- Parameters:
state_name (str) – Name of the state variable to set.
state_array (Array) – Time series of the state variable in the default Jaxley unit. State array should be of shape (num_clamps, simulation_time) or (simulation_time, ) for a single clamp.
verbose (bool) – Whether or not to print the number of inserted clamps. False by default because this method is meant to be jitted.
- delete_clamps(state_name=None)
Remove all clamps of the given state from the module.
- Parameters:
state_name (str | None)
- insert(channel)
Insert a channel into the module.
- Parameters:
channel (Channel) – The channel to insert.
- delete_channel(channel)
Remove a channel from the module.
- Parameters:
channel (Channel) – The channel to remove.
- vis(ax=None, col='k', dims=(0, 1), type='line', morph_plot_kwargs={})
Visualize the module.
Modules can be visualized on one of the cardinal planes (xy, xz, yz) or even in 3D.
Several options are available:
- line: All points from the traced morphology (xyzr) are connected with a line plot.
- scatter: All traced points are plotted as scatter points.
- comp: Plots the compartmentalized morphology, including radius and shape. By default, the true compartment lengths are shown, but this can be changed via morph_plot_kwargs; for details, see jaxley.utils.plot_utils.plot_comps.
- morph: Reconstructs the 3D shape of the traced morphology. For details, see jaxley.utils.plot_utils.plot_morph.
Warning: For 3D plots and morphologies with many traced points, this can be very slow.
- Parameters:
ax (Axes | None) – An axis into which to plot.
col (str) – The color for all branches.
dims (Tuple[int]) – Which dimensions to plot. 0=x, 1=y, 2=z coordinate. Must be a tuple of two of them.
type (str) – The type of plot. One of [“line”, “scatter”, “comp”, “morph”].
morph_plot_kwargs (Dict) – Keyword arguments passed to the plotting function.
- Return type:
Axes
- compute_xyz()
Return xyz coordinates of every branch, based on the branch length.
This function should not be called if the morphology was read from an .swc file. However, for morphologies that were constructed from scratch, this function must be called before .vis(). The computed xyz coordinates are only used for plotting.
- move(x=0.0, y=0.0, z=0.0, update_nodes=False)
Move cells or networks by adding to their (x, y, z) coordinates.
This function is used only for visualization. It does not affect the simulation.
- Parameters:
x (float) – The amount to move in the x direction in um.
y (float) – The amount to move in the y direction in um.
z (float) – The amount to move in the z direction in um.
update_nodes (bool) – Whether .nodes should be updated or not. Setting this to False greatly speeds up moving, especially for big networks, but .nodes and .show will not reflect the new xyz coordinates.
- move_to(x=0.0, y=0.0, z=0.0, update_nodes=False)
Move cells or networks to a location (x, y, z).
If x, y, and z are floats, then the first compartment of the first branch of the first cell is moved to that float coordinate, and everything else is shifted by the difference between that compartment’s previous coordinate and the new float location.
If x, y, and z are arrays, then they must each have a length equal to the number of cells being moved. Then the first compartment of the first branch of each cell is moved to the specified location.
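The two modes of move_to can be sketched with plain tuples. The helper move_cells_to is hypothetical: it operates on lists of (x, y, z) points rather than on a jaxley module, and it treats the first point of each cell as the first compartment of its first branch.

```python
from typing import List, Sequence, Tuple, Union

Point = Tuple[float, float, float]
Coord = Union[float, Sequence[float]]


def move_cells_to(cells: List[List[Point]], x: Coord, y: Coord, z: Coord) -> List[List[Point]]:
    """Shift cells so that their first point lands on a target location.

    Hypothetical sketch: each cell is a list of (x, y, z) points.
    """
    if all(isinstance(v, (int, float)) for v in (x, y, z)):
        # Floats: one shared offset, defined by the first point of the first cell.
        x0, y0, z0 = cells[0][0]
        offsets = [(x - x0, y - y0, z - z0)] * len(cells)
    else:
        # Arrays: one target per cell; each cell is shifted independently.
        if not len(x) == len(y) == len(z) == len(cells):
            raise ValueError("x, y, and z need one entry per cell.")
        offsets = [
            (xi - c[0][0], yi - c[0][1], zi - c[0][2])
            for xi, yi, zi, c in zip(x, y, z, cells)
        ]
    return [
        [(px + dx, py + dy, pz + dz) for px, py, pz in cell]
        for cell, (dx, dy, dz) in zip(cells, offsets)
    ]


cells = [[(1.0, 1.0, 0.0), (2.0, 1.0, 0.0)], [(5.0, 5.0, 0.0)]]
print(move_cells_to(cells, 0.0, 0.0, 0.0))
# [[(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)], [(4.0, 4.0, 0.0)]]
```

With float targets the relative arrangement of all cells is preserved; with per-cell arrays, each cell is placed independently.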
- rotate(degrees, rotation_axis='xy', update_nodes=False)
Rotate jaxley modules clockwise.
This function is used only for visualization. It does not affect the simulation.
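To show what a clockwise rotation within one plane (e.g. rotation_axis='xy') does to coordinates, the helper below applies the standard 2D rotation matrix with a negated angle. rotate_clockwise and its pivot argument are hypothetical; which pivot point jaxley rotates around is not stated here.

```python
import math
from typing import List, Tuple

Point2D = Tuple[float, float]


def rotate_clockwise(
    points: List[Point2D], degrees: float, pivot: Point2D = (0.0, 0.0)
) -> List[Point2D]:
    """Rotate 2D points clockwise by `degrees` around `pivot`.

    Sketch for a single plane; the choice of pivot is an assumption.
    """
    theta = math.radians(degrees)
    c, s = math.cos(theta), math.sin(theta)
    px, py = pivot
    rotated = []
    for x, y in points:
        dx, dy = x - px, y - py
        # Clockwise by theta equals counterclockwise by -theta.
        rotated.append((px + c * dx + s * dy, py - s * dx + c * dy))
    return rotated


print(rotate_clockwise([(0.0, 1.0)], 90))  # approximately [(1.0, 0.0)]
```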
- copy_node_property_to_edges(properties_to_import, pre_or_post=['pre', 'post'])
Copy a property that is in node over to edges.
By default, .edges does not contain the properties (radius, length, cm, channel properties, …) of the pre- and post-synaptic compartments. This method allows one to copy a property of the pre- and/or post-synaptic compartment to the edges. It is then accessible as module.edges.pre_property_name or module.edges.post_property_name.
Note that if you modify the node property after having run copy_node_property_to_edges, the value in .edges will not be updated automatically.
Note that if this method is called on a View (e.g., net.cell(0).copy_node_property_to_edges), it will return a View but will not modify the module itself.
- Parameters:
properties_to_import (str | List[str]) – The names of the node properties that should be imported. To list all available properties, look at module.nodes.columns.
pre_or_post (str | List[str]) – Whether to import only the pre-synaptic property (‘pre’), only the post-synaptic property (‘post’), or both ([‘pre’, ‘post’]).
- Returns:
A new module which has the property copied to the edges.
- Return type:
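Conceptually, the copy is a lookup from each edge's pre- and post-synaptic compartment into the node table. A hypothetical plain-Python sketch with list-of-dict tables; the "pre_index"/"post_index" keys are assumptions, not jaxley's column names.

```python
from typing import Dict, List, Sequence, Union


def copy_node_property_to_edges(
    nodes: List[Dict],
    edges: List[Dict],
    properties: Union[str, Sequence[str]],
    pre_or_post: Union[str, Sequence[str]] = ("pre", "post"),
) -> List[Dict]:
    """Return new edge records with node properties copied over.

    Hypothetical sketch: nodes[i] holds the properties of global compartment i,
    and each edge stores the indices of its pre- and post-synaptic
    compartments under "pre_index" and "post_index".
    """
    props = [properties] if isinstance(properties, str) else list(properties)
    sides = [pre_or_post] if isinstance(pre_or_post, str) else list(pre_or_post)
    new_edges = []
    for edge in edges:
        new_edge = dict(edge)  # copy the record so the input edges stay unchanged
        for side in sides:
            node = nodes[edge[f"{side}_index"]]
            for prop in props:
                # The value is copied once; later node changes are not tracked.
                new_edge[f"{side}_{prop}"] = node[prop]
        new_edges.append(new_edge)
    return new_edges


nodes = [{"radius": 1.0}, {"radius": 2.0}]
edges = [{"pre_index": 0, "post_index": 1}]
print(copy_node_property_to_edges(nodes, edges, "radius"))
# [{'pre_index': 0, 'post_index': 1, 'pre_radius': 1.0, 'post_radius': 2.0}]
```

As with the documented method, changing nodes afterwards does not update the copied values.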
- class Compartment
Compartment class.
This class defines a single compartment that can be simulated by itself or connected up into branches. It is the basic building block of a neuron model.
- class Branch(**kwargs)
Branch class.
This class defines a single branch that can be simulated by itself or connected to build a cell. A branch is a linear segment of several compartments and can be connected to zero, one, or more other branches at each end to build more intricate cell morphologies.
- class Cell(branches=None, parents=None, xyzr=None)
Cell class.
This class defines a single cell that can be simulated by itself or connected with synapses to build a network. A cell is made up of several branches and supports intricate cell morphologies.
- class Network(cells)
Network class.
This class defines a network of cells that can be connected with synapses.
- vis(detail='full', ax=None, col='k', synapse_col='b', dims=(0, 1), type='line', layers=None, morph_plot_kwargs={}, synapse_plot_kwargs={}, synapse_scatter_kwargs={}, networkx_options={}, layer_kwargs={})
Visualize the module.
- Parameters:
detail (str) – Either “point” or “full”. “point” visualizes every neuron in the network as a dot and uses networkx to obtain cell positions. “full” plots the full morphology of every neuron; it requires that compute_xyz() has been run and allows individual neurons to be moved with .move().
col (str) – The color in which cells are plotted. Only takes effect if detail=’full’.
type (str) – Either line or scatter. Only takes effect if detail=’full’.
synapse_col (str) – The color in which synapses are plotted. Only takes effect if detail=’full’.
dims (Tuple[int]) – Which dimensions to plot. 0=x, 1=y, 2=z coordinate. Must be a tuple of two of them.
layers (List | None) – Allows plotting the network in layers. Should provide the number of neurons in each layer, e.g., [5, 10, 1] is a network with 5 input neurons, 10 hidden-layer neurons, and 1 output neuron.
morph_plot_kwargs (Dict) – Keyword arguments passed to the plotting function for cell morphologies. Only takes effect for detail=’full’.
synapse_plot_kwargs (Dict) – Keyword arguments passed to the plotting function for synapses. Only takes effect for detail=’full’.
synapse_scatter_kwargs (Dict) – Keyword arguments passed to the scatter function for the end point of synapses. Only takes effect for detail=’full’.
networkx_options (Dict) – Options passed to networkx.draw(). Only takes effect if detail=’point’.
layer_kwargs (Dict) – Only used if layers is specified and if detail=’full’. Can have the following entries: within_layer_offset (float), between_layer_offset (float), vertical_layers (bool).
ax (Axes | None)
- Return type:
Axes
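The layers argument, together with the layer_kwargs entries within_layer_offset, between_layer_offset, and vertical_layers, suggests a simple grid layout. A hypothetical sketch of such a layout; layer_positions, its default offsets, and the centering of each layer are assumptions, not jaxley's plotting code.

```python
from typing import List, Tuple


def layer_positions(
    layers: List[int],
    within_layer_offset: float = 1.0,
    between_layer_offset: float = 5.0,
    vertical_layers: bool = False,
) -> List[Tuple[float, float]]:
    """Assign a 2D plotting position to every neuron, layer by layer.

    Hypothetical sketch: neurons are numbered consecutively across layers,
    and each layer is centered on the axis along which layers are stacked.
    """
    positions = []
    for layer_idx, num_neurons in enumerate(layers):
        along = layer_idx * between_layer_offset
        for i in range(num_neurons):
            # Offsets within a layer run symmetrically around zero.
            across = (i - (num_neurons - 1) / 2) * within_layer_offset
            # Vertical layers are columns (x = layer); otherwise rows (y = layer).
            positions.append((along, across) if vertical_layers else (across, along))
    return positions


pos = layer_positions([5, 10, 1])
print(len(pos), pos[0], pos[-1])  # 16 (-2.0, 0.0) (0.0, 10.0)
```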