jaxley.Compartment
- class Compartment
Bases: Module
A single compartment.
This class defines a single compartment that can be simulated by itself or connected up into branches. It is the basic building block of a neuron model.
- compartment_params: Dict = {'axial_resistivity': 5000.0, 'capacitance': 1.0, 'length': 10.0, 'radius': 1.0}
- add_to_group(group_name)
Add a view of the module to a group.
- Parameters:
group_name (str) – The name of the group.
Example usage
Define an excitatory group and use it to set parameters.
# net = ...
net.cell(0).add_to_group("excitatory")
net.excitatory.set("radius", 0.1)
- branch(idx)
Return a View of the module at the selected branch(es).
- Parameters:
idx (Any) – index of the branch to view.
- Returns:
View of the module at the specified branch index.
- Return type:
View
- property branches: Iterator[View]
Iterate over all branches in the module.
Returns a generator that yields a View of each branch.
Example usage
Iterating over the branches of a cell:
import jaxley as jx
comp = jx.Compartment()
branch1 = jx.Branch([comp] * 2)
branch2 = jx.Branch([comp] * 3)
cell = jx.Cell([branch1, branch1, branch2], parents=[-1, 0, 0])
nodes_per_branch = [len(branch.nodes) for branch in cell.branches]
- build_exp_euler_transition_matrix(delta_t, axial_conductances=None)
Compute the exponential of the transition matrix of the voltage diffusion.
For the linear ODE
\[\frac{dv}{dt} = G \, v(t),\]
an exponential Euler step of size \(dt\) is given by
\[v(t + dt) = e^{G \, dt} \, v(t).\]
The returned matrices are already stripped of entries corresponding to branch points, i.e. each of them has shape (ncomp, ncomp). They are intended to be applied to voltages[self._internal_node_inds].
- Parameters:
delta_t (float) – The time step used to compute the matrix exponential e^{G * dt}.
axial_conductances (Array | None) – An array which contains the axial conductances for every compartment edge (including those from and to branch points), for voltage and for all diffused ions. Shape (N+1, C), where N is the number of diffused ions (+1 for voltage), and C is the number of edges. To obtain this term, run module.get_all_parameters().
- Returns:
The transition matrix for the voltage and every diffused state as a matrix of shape (N+1, M, M), where N is the number of diffused states and M is the number of compartments.
- Return type:
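For intuition, the exponential Euler update above can be sketched in NumPy for a toy case. This is not Jaxley's implementation: it assumes a hypothetical unbranched chain of identical compartments with unit capacitance, coupled by a dimensionless axial conductance `g` and without leak or channels, so that G is symmetric and e^{G dt} can be computed from its eigendecomposition. The helper names `coupling_matrix` and `exp_euler_transition` are hypothetical.

```python
import numpy as np


def coupling_matrix(ncomp: int, g: float) -> np.ndarray:
    """Build G for an unbranched chain of identical compartments.

    Assumption: unit capacitance, equal axial conductance g between
    neighbours, no leak and no channels.
    """
    G = np.zeros((ncomp, ncomp))
    for i in range(ncomp - 1):
        # Axial current between compartments i and i+1.
        G[i, i] -= g
        G[i + 1, i + 1] -= g
        G[i, i + 1] += g
        G[i + 1, i] += g
    return G


def exp_euler_transition(G: np.ndarray, delta_t: float) -> np.ndarray:
    """Return e^{G * delta_t} for a symmetric G via its eigendecomposition."""
    eigvals, eigvecs = np.linalg.eigh(G)
    return eigvecs @ np.diag(np.exp(eigvals * delta_t)) @ eigvecs.T


delta_t = 0.025
G = coupling_matrix(ncomp=3, g=2.0)
transition = exp_euler_transition(G, delta_t)

v0 = np.array([-70.0, -60.0, -50.0])
v1 = transition @ v0  # one exponential Euler step: v(t + dt) = e^{G dt} v(t)
```

Because the update is exact for this linear ODE, applying the transition twice matches a single step of size 2 * delta_t, and since no current leaves this toy chain, the summed voltage is conserved.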
- cell(idx)
Return a View of the module at the selected cell(s).
- Parameters:
idx (Any) – index of the cell to view.
- Returns:
View of the module at the specified cell index.
- Return type:
View
- property cells: Iterator[View]
Iterate over all cells in the module.
Returns a generator that yields a View of each cell.
Example usage
Iterating over the cells of a network:
import jaxley as jx
cell = jx.Cell()
net = jx.Network([cell] * 4)
for cell_view in net.cells:
    print("Radii of the compartments of the cell:", cell_view.nodes["radius"].to_numpy())
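- clamp(state_name, state_array, verbose=True)
Clamp a state to a given value across specified compartments.
This function sets external states for the compartments.
- Parameters:
state_name (str) – Name of the state variable to clamp.
state_array – Time series of the state variable in the default Jaxley unit, of shape (num_clamps, simulation_time) or (simulation_time, ) for a single clamp.
verbose (bool) – Whether or not to print the number of inserted clamps.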
- comp(idx)
Return a View of the module at the selected compartment(s).
- Parameters:
idx (Any) – index of the comp to view.
- Returns:
View of the module at the specified compartment index.
- Return type:
View
- property comps: Iterator[View]
Iterate over all compartments in the module. Can be called on any module, i.e. net.comps, cell.comps or branch.comps. __iter__ does not allow for this.
Returns a generator that yields a View of each compartment.
Example usage
Iterating over the compartments of a branch:
import jaxley as jx
from jaxley.channels import Na, K
comp = jx.Compartment()
branch = jx.Branch([comp] * 2)
branch.comp(0).insert(Na())
branch.comp(0).insert(K())
num_channels_per_comp = [len(comp.channels) for comp in branch.comps]
Iterating over the compartments of a cell:
comp = jx.Compartment()
comp.insert(Na())
branch = jx.Branch([comp])
cell = jx.Cell([branch], parents=[-1])
cell.set_scope("global")
num_channels_per_comp = [len(comp.channels) for comp in cell.comps]
Iterating over the compartments of a network:
cell1 = jx.Cell()
cell1.insert(Na())
cell2 = jx.Cell()
cell2.insert(K())
cell2.insert(Na())
net = jx.Network([cell1, cell1, cell2])
net.set_scope("global")
num_channels_per_comp = [len(comp.channels) for comp in net.comps]
Note that in the examples above, the cell and network scopes are set to global in order to get an iterator over all compartments of the cell or network.
- compute_compartment_centers()
Add compartment centers to the nodes dataframe.
- compute_xyz()
Return xyz coordinates of every branch, based on the branch length.
This function should not be called if the morphology was read from an .swc file. However, for morphologies that were constructed from scratch, this function must be called before .vis(). The computed xyz coordinates are only used for plotting.
- copy(reset_index=False, as_module=False)
Extract part of a module and return a copy of its View or a new module.
This can be used to call jx.integrate on part of a Module.
- copy_node_property_to_edges(properties_to_import, pre_or_post=['pre', 'post'])
Copy a property that is in node over to edges.
By default, .edges does not contain the properties (radius, length, cm, channel properties, …) of the pre- and post-synaptic compartments. This method copies a property of the pre- and/or post-synaptic compartment to the edges. It is then accessible as module.edges.pre_property_name or module.edges.post_property_name.
Note that, if you modify the node property _after_ having run copy_node_property_to_edges, it will not automatically update the value in .edges.
Note that, if this method is called on a View (e.g. net.cell(0).copy_node_property_to_edges), then it will return a View, but it will _not_ modify the module itself.
- Parameters:
properties_to_import (str | List[str]) – The name of the node properties that should be imported. To list all available properties, look at module.nodes.columns.
pre_or_post (str | List[str]) – Whether to import only the pre-synaptic property (‘pre’), only the post-synaptic property (‘post’), or both ([‘pre’, ‘post’]).
- Returns:
A new module which has the property copied to the edges.
- Return type:
Module
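Conceptually, copying a node property to the edges is a lookup of each edge's pre- and post-synaptic compartment in the node table. The pandas sketch below illustrates this with hypothetical tables; the column names `pre_comp` and `post_comp` and the function name are assumptions for illustration and are not necessarily what Jaxley uses in `.edges`.

```python
import pandas as pd

# Hypothetical node table, indexed by compartment index.
nodes = pd.DataFrame({"radius": [1.0, 0.5, 2.0]}, index=[0, 1, 2])

# Hypothetical edge (synapse) table: each row connects two compartments.
edges = pd.DataFrame({"pre_comp": [0, 2], "post_comp": [1, 0]})


def copy_node_property_to_edges_sketch(nodes, edges, prop, pre_or_post=("pre", "post")):
    """Return a copy of `edges` with `pre_<prop>` and/or `post_<prop>` columns."""
    if isinstance(pre_or_post, str):
        pre_or_post = [pre_or_post]
    # Values are copied, so later edits to `nodes` are not reflected in the result.
    edges = edges.copy()
    for side in pre_or_post:
        # Look up the property of the pre-/post-synaptic compartment of every edge.
        edges[f"{side}_{prop}"] = nodes.loc[edges[f"{side}_comp"], prop].to_numpy()
    return edges


edges_with_radius = copy_node_property_to_edges_sketch(nodes, edges, "radius")
```

Because the values are copied rather than referenced, this also mirrors the note above that modifying a node property afterwards does not update `.edges`.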
- customize_solver_exp_euler(exp_euler_transition=None)
Sets internal attributes which customize the exponential Euler solver.
This function only takes effect when jx.integrate(..., solver='exp_euler') is used. The current state of these arguments is stored in module.solver_customizers.
- Parameters:
exp_euler_transition (Array | None) – A matrix of shape (ncomp x ncomp), where ncomp is the number of compartments. This matrix is returned by module.build_exp_euler_transition_matrix(delta_t). If passed, the matrix will _not_ be computed at the beginning of jx.integrate(). This can provide massive speed-ups, but it requires that the capacitance, axial resistivity, length, and radius of every compartment are known upfront (i.e., they are not being optimized or considered as free parameters). To revert to computing the matrix automatically within jx.integrate(), run module.customize_solver_exp_euler(exp_euler_transition=None).
Example usage
Optimize solver speed by pre-computing the exponential Euler transition matrix:
import jaxley as jx
delta_t = 0.025
cell = jx.Cell()
cell.record("v")
cell.customize_solver_exp_euler(
    exp_euler_transition=cell.build_exp_euler_transition_matrix(delta_t)
)
v = jx.integrate(cell, delta_t=delta_t, t_max=100.0, solver="exp_euler")
- data_clamp(state_name, state_array, data_clamps=None, verbose=False)
Insert a clamp into the module within jit (or grad).
- Parameters:
state_name (str) – Name of the state variable to set.
state_array (Array | ndarray | bool | number | bool | int | float | complex) – Time series of the state variable in the default Jaxley unit. State array should be of shape (num_clamps, simulation_time) or (simulation_time, ) for a single clamp.
verbose (bool) – Whether or not to print the number of inserted clamps. False by default because this method is meant to be jitted.
data_clamps (tuple[Array | ndarray | bool | number | bool | int | float | complex, DataFrame] | None)
- data_set(key, val, param_state)
Set parameter of module (or its view) to a new value within jit.
- Parameters:
key (str) – The name of the parameter to set.
val (float | Array | ndarray | bool | number | bool | int | complex) – The value to set the parameter to. If it is ArrayLike, then it must be of shape (num_compartments,).
param_state (list[dict] | None) – State of the set parameters, internally used such that this function does not modify global state.
- data_stimulate(current, data_stimuli=None, verbose=False)
Insert a stimulus into the module within jit (or grad).
- Parameters:
current (Array | ndarray | bool | number | bool | int | float | complex) – Current in nA.
verbose (bool) – Whether or not to print the number of inserted stimuli. False by default because this method is meant to be jitted.
data_stimuli (tuple[Array | ndarray | bool | number | bool | int | float | complex, DataFrame] | None)
- Return type:
- delete(channel)
Remove a channel or pump from the module.
Example usage
The example below inserts two channels and then deletes the sodium channel:
import jaxley as jx
from jaxley.channels import Na, K
cell = jx.Cell()
cell.insert(Na())
cell.insert(K())
cell.delete(Na())
The example below inserts two channels and then deletes both of them:
import jaxley as jx
from jaxley.channels import Na, K
cell = jx.Cell()
cell.insert(Na())
cell.insert(K())
# Loop over all channels and delete each one. Note: The `list(...)` is
# important here because `.delete()` modifies `.channels`.
for channel in list(cell.channels):
    cell.delete(channel)
- delete_clamps(state_name=None)
Removes all clamps of the given state from the module.
- Parameters:
state_name (str | None)
- delete_diffusion(state)
Deletes ion diffusion in the entire module.
- Parameters:
state (str) – Name of the state that should no longer be diffused.
- Return type:
None
- delete_recordings()
Removes all recordings from the module.
- delete_stimuli()
Removes all stimuli from the module.
- delete_trainables()
Removes all trainable parameters from the module.
- diffuse(state)
Diffuse a particular state across compartments with Fickian diffusion.
- Parameters:
state (str) – Name of the state that should be diffused.
- Return type:
None
Example usage
Diffuse calcium ions across a cell:
import jaxley as jx
from jaxley.pumps import CaNernstReversal
comp = jx.Compartment()
branch = jx.Branch(comp, ncomp=2)
cell = jx.Cell(branch, parents=[-1, 0])
cell.insert(CaNernstReversal())
cell.diffuse("CaCon_i")  # Diffuse calcium ions through the cell
cell.branch(0).set("CaCon_i", 0.2)
cell.record("CaCon_i")
simulated_concentrations = jx.integrate(cell, t_max=5.0)
- edge(idx)
Return a View of the module at the selected synapse edge(s).
- Parameters:
idx (Any) – index of the edge to view.
- Returns:
View of the module at the specified edge index.
- Return type:
View
- get_parameters()
Get all trainable parameters.
The returned parameters should be passed to jx.integrate(..., params=params).
- Returns:
A list of all trainable parameters in the form of [{“gNa”: jnp.array([0.1, 0.2, 0.3])}, …].
- Return type:
Example usage
import jaxley as jx
cell = jx.Cell()
cell.record("v")
cell.make_trainable("radius")
params = cell.get_parameters()
v = jx.integrate(cell, params=params, t_max=10.0)
- init_params()
Run channel.init_params() to initialize parameters.
- init_states(delta_t=0.025)
Initialize all mechanisms in their steady state.
This considers the voltages and parameters of each compartment.
- Parameters:
delta_t (float) – Passed on to channel.init_state().
- initialize()
Initialize the module.
This function does several things: 1) It computes local indices in the .nodes dataframe (from global indices). 2) It builds the compartment graph (._comp_edges and ._branchpoints). 3) It initializes the View. 4) It initializes all solvers required for solving the differential equation.
This function should be run whenever the graph structure (i.e., the morphology or the compartmentalization) of the module has changed. Built-in functions such as morph_attach(), morph_delete(), or set_ncomp() run this function automatically (unless told otherwise, e.g. set_ncomp(..., initialize=False)), so there is no need for the user to run it manually after calling them.
- insert(channel)
Insert a channel or pump into the module.
- loc(at)
Return a View of the module at the selected branch location(s).
- Parameters:
at (Any) – location along the branch.
- Returns:
View of the module at the specified branch location.
- Return type:
View
- make_trainable(key, init_val=None, verbose=True)
Make a parameter trainable.
If a parameter is made trainable, it will be returned by get_parameters() and should then be passed to jx.integrate(…, params=params).
- Parameters:
key (str) – Name of the parameter to make trainable.
init_val (float | list | None) – Initial value of the parameter. If float, the same value is used for every created parameter. If list, the length of the list has to match the number of created parameters. If None, the current parameter value is used; if parameter sharing is performed, the current parameter values are averaged over all shared parameters.
verbose (bool) – Whether to print the number of parameters that are added and the total number of parameters.
Example usage
Making a parameter of a compartment trainable:
import jaxley as jx
comp = jx.Compartment()
comp.make_trainable("radius")
parameters = comp.get_parameters()  # -> [{'radius': Array([1.], dtype=float32)}]
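The `init_val` rules described above can be sketched in plain Python. This is an illustration, not Jaxley's implementation; the function name `resolve_init_vals` and the representation of parameter sharing as a list of compartment-index groups are assumptions.

```python
from typing import List, Optional, Sequence, Union


def resolve_init_vals(
    current_values: Sequence[float],
    shared_groups: List[List[int]],
    init_val: Optional[Union[float, list]] = None,
) -> List[float]:
    """Return one initial value per created (possibly shared) parameter.

    current_values: current parameter value of every compartment (hypothetical input).
    shared_groups: for every created parameter, the compartments that share it.
    """
    num_params = len(shared_groups)
    if init_val is None:
        # Use the current values, averaged over all compartments sharing a parameter.
        return [sum(current_values[i] for i in group) / len(group) for group in shared_groups]
    if isinstance(init_val, (int, float)):
        # A scalar is used for every created parameter.
        return [float(init_val)] * num_params
    if len(init_val) != num_params:
        raise ValueError(f"Expected {num_params} initial values, got {len(init_val)}.")
    return [float(v) for v in init_val]
```

For example, four compartments with values [1.0, 2.0, 3.0, 4.0] that share parameters pairwise, [[0, 1], [2, 3]], yield the initial values [1.5, 3.5] when init_val is None.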
- move(x=0.0, y=0.0, z=0.0, update_nodes=False)
Move cells or networks by adding to their (x, y, z) coordinates.
This function is used only for visualization. It does not affect the simulation.
- Parameters:
x (float) – The amount to move in the x direction in um.
y (float) – The amount to move in the y direction in um.
z (float) – The amount to move in the z direction in um.
update_nodes (bool) – Whether .nodes should be updated or not. Setting this to False greatly speeds up moving, especially for big networks, but .nodes or .show() will then not show the new xyz coordinates.
Example usage
Move an entire cell, which moves its branches accordingly:
import jaxley as jx
comp = jx.Compartment()
branch = jx.Branch([comp])
cell = jx.Cell([branch] * 3, parents=[-1, 0, 0])
cell.move(20.0, 30.0, 5.0)
- move_to(x=0.0, y=0.0, z=0.0, update_nodes=False)
Move cells or networks to a location (x, y, z).
If x, y, and z are floats, then the first compartment of the first branch of the first cell is moved to that float coordinate, and everything else is shifted by the difference between that compartment’s previous coordinate and the new float location.
If x, y, and z are arrays, then they must each have a length equal to the number of cells being moved. Then the first compartment of the first branch of each cell is moved to the specified location.
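- Parameters:
x – Target x coordinate in um (a float, or an array with one entry per moved cell).
y – Target y coordinate in um (a float, or an array with one entry per moved cell).
z – Target z coordinate in um (a float, or an array with one entry per moved cell).
update_nodes (bool) – Whether .nodes should be updated or not, as for move().
Example usage
Move an entire cell to a specified location:
import jaxley as jx
comp = jx.Compartment()
branch = jx.Branch([comp])
cell = jx.Cell([branch] * 3, parents=[-1, 0, 0])
cell.move_to(20.0, 30.0, 5.0)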
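For the float case, moving to (x, y, z) amounts to computing one offset from the first compartment's current coordinate and adding it to every coordinate. A NumPy sketch, assuming (for illustration only) that the coordinates are held in an (n, 3) array whose first row belongs to the first compartment of the first branch of the first cell; `move_to_sketch` is a hypothetical name.

```python
import numpy as np


def move_to_sketch(xyz: np.ndarray, x: float, y: float, z: float) -> np.ndarray:
    """Shift all points so that the first point lands on (x, y, z)."""
    offset = np.array([x, y, z]) - xyz[0]
    # Every point is shifted by the same offset, so relative positions are preserved.
    return xyz + offset


xyz = np.array([[1.0, 2.0, 0.0], [11.0, 2.0, 0.0], [21.0, 2.0, 0.0]])
moved = move_to_sketch(xyz, 20.0, 30.0, 5.0)
```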
- recording(state='v')
Returns all available recordings in the viewed part of the module.
This function can only run after module.write_recordings() has been run.
- Parameters:
state (str) – The state that should be returned.
Example usage
import jaxley as jx
comp = jx.Compartment()
branch = jx.Branch(comp, ncomp=2)
cell = jx.Cell(branch, parents=[-1, 0])
cell.record("v")
v = jx.integrate(cell, t_max=10.0)
cell.write_recordings(v)
cell.branch(0).comp(0).recording("v")
- rotate(degrees, rotation_axis='xy', update_nodes=False)
Rotate jaxley modules clockwise.
This function is used only for visualization. It does not affect the simulation.
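The geometry of a clockwise rotation in the xy plane can be sketched with a 2D rotation matrix. The sketch below assumes rotation about the centroid of the points, which may differ from the pivot Jaxley uses; the function name is hypothetical and only the geometry is illustrated.

```python
import numpy as np


def rotate_xy_clockwise(xy: np.ndarray, degrees: float) -> np.ndarray:
    """Rotate (n, 2) points clockwise by `degrees` about their centroid."""
    theta = np.deg2rad(degrees)
    # Clockwise rotation matrix (the transpose of the counter-clockwise one).
    rot = np.array([[np.cos(theta), np.sin(theta)],
                    [-np.sin(theta), np.cos(theta)]])
    center = xy.mean(axis=0)
    # Row vectors: p @ rot.T is equivalent to rot @ p for each point p.
    return (xy - center) @ rot.T + center


points = np.array([[0.0, 0.0], [2.0, 0.0]])
rotated = rotate_xy_clockwise(points, 90.0)
```

Rotating the two points by 90 degrees moves the western point to the north of the centroid and the eastern point to the south, i.e. [[1, 1], [1, -1]].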
- scope(scope)
Return a View of the module with the specified scope.
For example, cell.scope(“global”).branch(2).scope(“local”).comp(1) will return the compartment with local index 1 of the branch with global index 2.
- Parameters:
scope (str) – either “global” or “local”.
- Returns:
View with the specified scope.
- Return type:
View
Example usage
Access the first compartment (local index 0) of the third branch (index 2) of a cell:
import jaxley as jx
from jaxley.channels import Na, K
comp = jx.Compartment()
branch = jx.Branch(comp, ncomp=3)
cell = jx.Cell(branch, parents=[-1, 0, 0])
cell.scope("global").branch(2).scope("local").comp(0).insert(K())
Access the sixth (global) compartment of the cell:
cell.scope("global").comp(6).insert(Na())
Note that in both cases we are inserting into the same compartment. Since there are 3 compartments per branch, the global index of the first compartment in the third branch is six. Locally, the first compartment is naturally 0.
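The conversion from a (branch, local compartment) pair to a global compartment index is a cumulative sum over the compartment counts of the preceding branches. A plain-Python sketch with a hypothetical helper name, which also covers branches with different numbers of compartments (e.g. after set_ncomp()):

```python
from itertools import accumulate
from typing import Sequence


def global_comp_index(ncomp_per_branch: Sequence[int], branch: int, local_comp: int) -> int:
    """Map a local compartment index within `branch` to a global index."""
    if not 0 <= local_comp < ncomp_per_branch[branch]:
        raise IndexError("local compartment index out of range for this branch")
    # offsets[b] is the number of compartments in all branches before branch b.
    offsets = [0, *accumulate(ncomp_per_branch)]
    return offsets[branch] + local_comp
```

With 3 compartments in each of 3 branches, global_comp_index([3, 3, 3], 2, 0) gives 6, matching the example above.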
- select(nodes=None, edges=None, sorted=False)
Return View of the module filtered by specific node or edges indices.
The selection is made based on the index of the self.nodes or self.edges, i.e., not on a local compartment index or a local row number (loc, not iloc).
- Parameters:
- Returns:
View for subset of selected nodes and/or edges.
- Return type:
View
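The difference between selecting by index label (loc) and by row position (iloc) matters once the index of a table is no longer 0, 1, 2, .... The pandas example below uses a hypothetical table whose index labels do not match the row positions:

```python
import pandas as pd

# Hypothetical node table whose index labels are not the row positions.
nodes = pd.DataFrame({"radius": [1.0, 0.5, 2.0]}, index=[2, 0, 1])

by_label = nodes.loc[[0]]      # label-based: the row labelled 0 -> radius 0.5
by_position = nodes.iloc[[0]]  # position-based: the first row -> label 2, radius 1.0
```

Per the description above, select() follows the label-based (loc) behaviour.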
- set(key, val)
Set parameter of module (or its view) to a new value.
Note that this function cannot be called within jax.jit or jax.grad. Instead, it should be used to set the parameters of the module before the simulation. Use .data_set() to set parameters during jax.jit or jax.grad.
- Parameters:
key (str) – The name of the parameter to set.
val – The value to set the parameter to.
Example usage
Setting the sodium maximal conductance of a compartment:
import jaxley as jx
from jaxley.channels import Na
comp = jx.Compartment()
comp.insert(Na())
comp.set("Na_gNa", 0.008)
Setting the parameter of a synapse for all synapses within a network:
net.select(edges="all").set("IonotropicSynapse_gS", 5e-4)
- set_ncomp(ncomp, min_radius=None, initialize=True)
Set the number of compartments with which the branch is discretized.
- Parameters:
ncomp (int) – The number of compartments that the branch should be discretized into.
min_radius (float | None) – Only used if the morphology was read from an SWC file. If passed, the radius is capped to be at least this value.
initialize (bool) – If False, the initialization stage is skipped and the user has to run cell.initialize() manually afterwards. This is useful when set_ncomp is run in a loop (e.g. for the d_lambda rule), where one can initialize only once after the entire loop to greatly speed up computation.
- Raises:
- When there are stimuli in any compartment in the module.
- When there are recordings in any compartment in the module.
- When the channels of the compartments are not the same within the branch that is modified.
- When the lengths of the compartments are not the same within the branch that is modified.
- When the branch that is modified has compartments belonging to different groups.
- Unless the morphology was read from an SWC file, when the radii of the compartments are not the same within the branch that is modified.
Example usage
Changing how many compartments each branch in a cell consists of:
import jaxley as jx
comp = jx.Compartment()
branch = jx.Branch([comp] * 4)
cell = jx.Cell([branch] * 3, parents=[-1, 0, 0])
cell.branch(0).set_ncomp(1)
cell.branch(1).set_ncomp(2)
cell.branch(2).set_ncomp(3)
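The d_lambda rule mentioned above chooses the number of compartments from the AC length constant of a branch. The sketch below follows the formula popularized by NEURON's d_lambda rule (frequency 100 Hz, d_lambda = 0.1) and assumes the units of compartment_params: radius and length in um, axial resistivity in ohm cm, capacitance in uF/cm². The function `d_lambda_ncomp` is hypothetical and not part of Jaxley.

```python
import math


def d_lambda_ncomp(
    length: float,
    radius: float,
    axial_resistivity: float,
    capacitance: float,
    freq: float = 100.0,
    d_lambda: float = 0.1,
) -> int:
    """Number of compartments for a branch according to the d_lambda rule.

    Assumed units: length and radius in um, axial_resistivity in ohm cm,
    capacitance in uF/cm^2, freq in Hz.
    """
    diameter = 2.0 * radius
    # AC length constant (in um) at frequency `freq`.
    lambda_f = 1e5 * math.sqrt(diameter / (4.0 * math.pi * freq * axial_resistivity * capacitance))
    # Round to an odd number of compartments, as in NEURON's d_lambda rule.
    return int((length / (d_lambda * lambda_f) + 0.9) / 2) * 2 + 1
```

With the default compartment parameters (length 10 um, radius 1 um, 5000 ohm cm, 1 uF/cm²) this yields 3 compartments. In a loop over branches, the result could be passed to set_ncomp(..., initialize=False), followed by a single call to initialize() after the loop, as described for the initialize parameter.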
- set_scope(scope)
Toggle between “global” or “local” scope.
Determines if global or local indices are used for viewing the module.
- Parameters:
scope (str) – either “global” or “local”.
Example usage
Access the first compartment (local index 0) of the third branch (index 2) of a cell:
import jaxley as jx
from jaxley.channels import Na, K
comp = jx.Compartment()
branch = jx.Branch(comp, ncomp=3)
cell = jx.Cell(branch, parents=[-1, 0, 0])
cell.set_scope("local")  # this is also the default
cell.branch(2).comp(0).insert(Na())
Access the sixth (global) compartment of the cell:
cell.set_scope("global")
cell.comp(6).insert(K())
Note that we are inserting into the same compartment in both cases. Since there are 3 compartments per branch, the global index of the first compartment in the third branch is six. Locally, the first compartment is naturally 0.
- property shape: Tuple[int]
Returns the number of submodules contained in a module.
network.shape = (num_cells, num_branches, num_compartments)
cell.shape = (num_branches, num_compartments)
branch.shape = (num_compartments,)
- show(param_names=None, *, indices=True, params=True, states=True, channel_names=None)
Print detailed information about the Module or a view of it.
- Parameters:
param_names (str | List[str] | None) – The names of the parameters to show. If None, all parameters are shown.
indices (bool) – Whether to show the indices of the compartments.
params (bool) – Whether to show the parameters of the compartments.
states (bool) – Whether to show the states of the compartments.
channel_names (List[str] | None) – The names of the channels to show. If None, all channels are shown.
- Returns:
A pd.DataFrame with the requested information.
- Return type:
DataFrame
- step(u, delta_t, external_inds, externals, params, solver='bwd_euler', voltage_solver='jaxley.stone')
One step of solving the Ordinary Differential Equation.
This function is called inside of integrate and increments the state of the module by one time step. Calls _step_channels and _step_synapse to update the states of the channels and synapses.
- Parameters:
u (dict[str, Array | ndarray | bool | number | bool | int | float | complex]) – The state of the module. voltages = u[“v”]
delta_t (float) – The time step.
external_inds (dict[str, Array | ndarray | bool | number | bool | int | float | complex]) – The indices of the external inputs.
externals (dict[str, Array | ndarray | bool | number | bool | int | float | complex]) – The external inputs.
solver (str) – The solver to use for the voltages. One of [“bwd_euler”, “fwd_euler”, “crank_nicolson”].
voltage_solver (str) – The tridiagonal solver used to solve the linear system of the implicit voltage update. One of [“jaxley.thomas”, “jaxley.stone”].
- Returns:
The updated state of the module.
- Return type:
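With the backward Euler solver, each voltage update requires solving a tridiagonal linear system; for an unbranched cable this can be done with the Thomas algorithm, after which the "jaxley.thomas" option is named. The NumPy sketch below is a textbook version for a single unbranched chain with a hypothetical coupling conductance `g`; it is not Jaxley's implementation, which also handles branch points and runs in JAX.

```python
import numpy as np


def thomas_solve(lower: np.ndarray, diag: np.ndarray, upper: np.ndarray, rhs: np.ndarray) -> np.ndarray:
    """Solve a tridiagonal system A x = rhs with the Thomas algorithm.

    lower[i] is A[i+1, i], diag[i] is A[i, i], upper[i] is A[i, i+1].
    """
    n = len(diag)
    if n == 1:
        return rhs / diag
    c = np.zeros(n - 1)
    d = np.zeros(n)
    # Forward sweep: eliminate the sub-diagonal.
    c[0] = upper[0] / diag[0]
    d[0] = rhs[0] / diag[0]
    for i in range(1, n):
        denom = diag[i] - lower[i - 1] * c[i - 1]
        if i < n - 1:
            c[i] = upper[i] / denom
        d[i] = (rhs[i] - lower[i - 1] * d[i - 1]) / denom
    # Back substitution.
    x = np.zeros(n)
    x[-1] = d[-1]
    for i in range(n - 2, -1, -1):
        x[i] = d[i] - c[i] * x[i + 1]
    return x


# One backward Euler step (I - delta_t * G) v_next = v_prev for a 4-compartment
# chain, where G has -g * degree on the diagonal and g on the off-diagonals.
g, delta_t = 2.0, 0.025
v_prev = np.array([-70.0, -65.0, -60.0, -55.0])
degree = np.array([1.0, 2.0, 2.0, 1.0])
diag = 1.0 + delta_t * g * degree
off = np.full(3, -delta_t * g)
v_next = thomas_solve(off, diag, off, v_prev)
```

The algorithm needs only O(n) operations for an n-compartment chain, compared with O(n^3) for a dense solve.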
- stimulate(current=None, verbose=True)
Insert a stimulus into the compartments of the module (or its view).
current must be a 1d array or have batch dimension of size (num_compartments, ) or (1, ). If 1d, the same stimulus is added to all compartments.
This function cannot be run during jax.jit and jax.grad. Because of this, it should only be used for static stimuli (i.e., stimuli that do not depend on the data and that should not be learned). For stimuli that depend on data (or that should be learned), please use data_stimulate().
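The accepted shapes for current can be illustrated with NumPy. The step-current helper below is a hypothetical stand-in (Jaxley provides jx.step_current for building such traces), and the time grid of int(t_max / delta_t) + 1 points is an assumption about the sampling convention; `broadcast_current` is likewise hypothetical and only shows how a 1d or (1, T) current maps onto every compartment.

```python
import numpy as np


def step_current_sketch(i_delay, i_dur, i_amp, delta_t, t_max):
    """Hypothetical helper: current (nA) equal to i_amp between i_delay and i_delay + i_dur."""
    time = np.arange(int(round(t_max / delta_t)) + 1) * delta_t
    return np.where((time >= i_delay) & (time < i_delay + i_dur), i_amp, 0.0)


def broadcast_current(current, num_compartments):
    """Expand a 1d current, or a (1, T) batch, to shape (num_compartments, T)."""
    current = np.atleast_2d(current)
    if current.shape[0] not in (1, num_compartments):
        raise ValueError("batch dimension must be 1 or num_compartments")
    # The same stimulus is used for every compartment when the batch size is 1.
    return np.broadcast_to(current, (num_compartments, current.shape[1]))


current = step_current_sketch(i_delay=1.0, i_dur=2.0, i_amp=0.1, delta_t=0.025, t_max=5.0)
per_comp = broadcast_current(current, num_compartments=4)
```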
- to_jax()
Move .nodes to .jaxnodes.
Before the actual simulation is run (via jx.integrate), all parameters of the jx.Module are stored in .nodes (a pd.DataFrame). However, for simulation, these parameters have to be moved to be jnp.ndarrays such that they can be processed on GPU/TPU and such that the simulation can be differentiated. .to_jax() copies the .nodes to .jaxnodes.
- property view
Return view of the module.
- vis(ax=None, color='k', dims=(0, 1), type='line', **kwargs)
Visualize the module.
Modules can be visualized on one of the cardinal planes (xy, xz, yz) or even in 3D.
Several options are available:
- line: All points from the traced morphology (xyzr) are connected with a line plot.
- scatter: All traced points are plotted as scatter points.
- comp: Plots the compartmentalized morphology, including radius and shape (shows the true compartment lengths by default, but this can be changed via the kwargs; for details see jaxley.utils.plot_utils.plot_comps).
- morph: Reconstructs the 3D shape of the traced morphology. For details see jaxley.utils.plot_utils.plot_morph. Warning: For 3D plots and morphologies with many traced points this can be very slow.
- Parameters:
ax (Axes | None) – An axis into which to plot.
color (str) – The color for all branches.
dims (Tuple[int]) – Which dimensions to plot. 0=x, 1=y, 2=z coordinate. Must be a tuple of two of them.
type (str) – The type of plot. One of [“line”, “scatter”, “comp”, “morph”].
kwargs – Keyword arguments passed to the plotting function.
- Return type:
Axes
- write_recordings(recordings)
Write recordings returned by jx.integrate into the module.
After having run write_recordings(), the recordings can be accessed via module.recording().
- Parameters:
recordings (Array) – An array of shape (N, T), where N is the number of recorded states and T is time. N must match len(module.rec_indices). This array is usually returned by jx.integrate().
Example usage
import jaxley as jx
comp = jx.Compartment()
branch = jx.Branch(comp, ncomp=2)
cell = jx.Cell(branch, parents=[-1, 0])
cell.record("v")
v = jx.integrate(cell, t_max=10.0)
cell.write_recordings(v)
cell.branch(0).comp(0).recording("v")
- write_trainables(trainable_params)
Write the trainables into .nodes and .edges.
This allows, e.g., visualizing trained networks with .vis().
- Parameters:
trainable_params (list[dict[str, Array]]) – The trainable parameters returned by get_parameters().
Example usage
Write new parameters to the model after training:
parameters = net.get_parameters()
# Assume you have some training function that gives you new parameters
new_parameters = train_network(net, parameters)
net.write_trainables(new_parameters)
print(net.nodes)  # outputs nodes of the model with the new parameters