Utils

build_radiuses_from_xyzr(radius_fns, branch_indices, min_radius, ncomp)

Return the radii of branches given the xyzr coordinates of an SWC file.

Returns an array of shape (num_branches, ncomp).

Parameters:
  • radius_fns (List[Callable]) – Functions which, given compartment locations, return the radius.

  • branch_indices (List[int]) – The indices of the branches for which to return the radiuses.

  • min_radius (float | None) – If passed, the radii are clipped to be at least this large.

  • ncomp (int) – The number of compartments that every branch is discretized into.

Return type:

Array
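A minimal NumPy sketch of the shape logic, under two assumptions not stated above: each radius function is evaluated at the ncomp compartment centres (i + 0.5) / ncomp, and min_radius clips from below. The helper `build_radiuses_sketch` is hypothetical, not the library's implementation.

```python
import numpy as np

def build_radiuses_sketch(radius_fns, branch_indices, min_radius, ncomp):
    """Hypothetical re-implementation for illustration only."""
    # Assumed sampling positions: the centres of ncomp equal compartments in [0, 1].
    locs = (np.arange(ncomp) + 0.5) / ncomp
    # One row per requested branch, one column per compartment.
    radiuses = np.stack(
        [np.asarray(radius_fns[b](locs), dtype=float) for b in branch_indices]
    )
    if min_radius is not None:
        # Clip from below so that no radius is smaller than min_radius.
        radiuses = np.maximum(radiuses, min_radius)
    return radiuses

fns = [lambda x: 1.0 + x, lambda x: np.full_like(x, 2.0)]
print(build_radiuses_sketch(fns, [0, 1], min_radius=1.5, ncomp=2))
```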

equal_segments(branch_property, ncomp_per_branch)

Generates segments where some property is the same in each segment.

Parameters:
  • branch_property (list) – List of values of the property in each branch. Should have len(branch_property) == num_branches.

  • ncomp_per_branch (int) – The number of compartments per branch.
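A minimal sketch, assuming the result has shape (num_branches, ncomp_per_branch) with each branch's value repeated in every compartment of that branch; `equal_segments_sketch` is a hypothetical name.

```python
import numpy as np

def equal_segments_sketch(branch_property, ncomp_per_branch):
    # Hypothetical re-implementation: repeat each branch value once per compartment.
    values = np.asarray(branch_property, dtype=float)
    return np.repeat(values[:, None], ncomp_per_branch, axis=1)

print(equal_segments_sketch([1.0, 2.0, 3.0], ncomp_per_branch=4).shape)  # (3, 4)
```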

linear_segments(initial_val, endpoint_vals, parents, ncomp_per_branch)

Generates segments where some property is linearly interpolated.

Parameters:
  • initial_val (float) – The value at the tip of the soma.

  • endpoint_vals (list) – The value at the endpoints of each branch.

  • parents (Array)

  • ncomp_per_branch (int) – The number of compartments per branch.

merge_cells(cumsum_num_branches, cumsum_num_branchpoints, arrs, exclude_first=True)

Build the full list of which branches are solved in which iteration.

From the branching pattern of single cells, this “merges” them into a single ordering of branches.

Parameters:
  • cumsum_num_branches (List[int]) – Cumulative number of branches. E.g., for three cells with 10, 15, and 5 branches respectively, this should be a list containing [0, 10, 25, 30].

  • arrs (List[List[ndarray]]) – A list of a list of arrays that should be merged.

  • exclude_first (bool) – If True, the first element of each list in arrs will remain unchanged. Useful if a -1 (which indicates “no parent”) entry should not be changed.

  • cumsum_num_branchpoints (List[int])

Returns:

A list of arrays which contain the branch indices that are computed at each level (i.e., iteration).

Return type:

ndarray

compute_children_indices(parents)

Return all children indices of every branch.

Example: for `parents = [-1, 0, 0]`, `compute_children_indices(parents)` returns `[[1, 2], [], []]`.

Return type:

List[Array]
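One straightforward way to compute the children, shown as a hypothetical NumPy sketch (`compute_children_indices_sketch` is not the library's code): branch i's children are all branches whose parent is i.

```python
import numpy as np

def compute_children_indices_sketch(parents):
    # Hypothetical re-implementation: branch i's children are the branches whose parent is i.
    parents = np.asarray(parents)
    return [np.where(parents == i)[0] for i in range(len(parents))]

children = compute_children_indices_sketch([-1, 0, 0])
print([c.tolist() for c in children])  # [[1, 2], [], []]
```

This scans the parents array once per branch, which is quadratic in the number of branches; a single pass that appends each branch to its parent's list would be linear.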

get_num_neighbours(num_children, ncomp_per_branch, num_branches)

Number of neighbours of each compartment.

Parameters:
  • num_children (Array)

  • ncomp_per_branch (int) – The number of compartments per branch.

  • num_branches (int)

local_index_of_loc(loc, global_branch_ind, ncomp_per_branch)

Return the local index of a compartment given a location in [0, 1] and the index of a branch.

This is used because locations, such as those of synapses, are specified as a value between 0 and 1, which has to be converted to a discrete compartment.

Parameters:
  • global_branch_ind (int) – Global index of the branch.

  • loc (float) – Location (in [0, 1]) along that branch.

  • ncomp_per_branch (int) – Number of compartments of each branch.

Returns:

The local index of the compartment.

Return type:

int
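A minimal sketch of the conversion, assuming compartment 0 starts at loc = 0 and all compartments have equal length. Under these assumptions the local index depends only on loc and ncomp_per_branch, so the hypothetical `local_index_of_loc_sketch` drops global_branch_ind.

```python
import math

def local_index_of_loc_sketch(loc, ncomp_per_branch):
    """Hypothetical helper: map loc in [0, 1] to a compartment index along one branch."""
    # Split [0, 1] into ncomp_per_branch equal bins; bin 0 is assumed to start at loc = 0.
    index = math.floor(loc * ncomp_per_branch)
    # loc = 1.0 lands exactly on the upper edge, so clamp it to the last compartment.
    return min(index, ncomp_per_branch - 1)

print([local_index_of_loc_sketch(loc, 4) for loc in (0.0, 0.3, 0.5, 1.0)])  # [0, 1, 2, 3]
```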

loc_of_index(global_comp_index, global_branch_index, ncomp_per_branch)

Return the location corresponding to a global compartment index.
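A hypothetical sketch of this mapping, assuming each branch stores its ncomp_per_branch compartments contiguously and that the returned location is the compartment centre; neither assumption is confirmed above, and `loc_of_index_sketch` is not the library's function.

```python
def loc_of_index_sketch(global_comp_index, global_branch_index, ncomp_per_branch):
    """Hypothetical helper: location (centre) of a compartment given its global index."""
    # Assumes every branch has ncomp_per_branch compartments stored contiguously.
    local_index = global_comp_index - global_branch_index * ncomp_per_branch
    # The centre of compartment i out of n equal bins in [0, 1] is (i + 0.5) / n.
    return (local_index + 0.5) / ncomp_per_branch

print(loc_of_index_sketch(5, 1, 4))  # 0.375
```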

compute_g_long(rad1, rad2, g_a1, g_a2, l1, l2)

Return the axial conductance between two compartments.

Equations taken from https://en.wikipedia.org/wiki/Compartmental_neuron_models.

The axial conductance is: g_long = 2 * pi * rad1^2 * rad2^2 / (l1 * r_a1 * rad2^2 + l2 * r_a2 * rad1^2)

Here, we define g_a = 1/r_a, because g_a can be zero (but not infinity as this would be inherently unstable).
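As a minimal sketch, the formula above can be rewritten in terms of g_a = 1/r_a by multiplying numerator and denominator by g_a1 * g_a2, which avoids dividing by zero when one g_a is 0. The helper name `g_long_sketch` and the zero-denominator guard are assumptions for illustration, not the library's code, and no unit-conversion constants are applied.

```python
import math

def g_long_sketch(rad1, rad2, g_a1, g_a2, l1, l2):
    """Hypothetical helper: axial conductance between two compartments, with g_a = 1 / r_a."""
    # 2*pi*rad1^2*rad2^2 / (l1*r_a1*rad2^2 + l2*r_a2*rad1^2), multiplied through by g_a1 * g_a2.
    numerator = 2 * math.pi * rad1**2 * rad2**2 * g_a1 * g_a2
    denominator = l1 * g_a2 * rad2**2 + l2 * g_a1 * rad1**2
    if denominator == 0:
        # E.g. both g_a are zero (r_a infinite on both sides): no coupling.
        return 0.0
    return numerator / denominator

# r_a1 = 2 and r_a2 = 4 correspond to g_a1 = 0.5 and g_a2 = 0.25.
print(g_long_sketch(1.0, 2.0, 0.5, 0.25, 3.0, 4.0))
```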

g_long_by_surface_area(rad1, rad2, g_a1, g_a2, l1, l2)

Return the voltage coupling conductance between two compartments.

Equations taken from https://en.wikipedia.org/wiki/Compartmental_neuron_models.

The axial conductance is: g_long = 2 * pi * rad1^2 * rad2^2 / (l1 * r_a1 * rad2^2 + l2 * r_a2 * rad1^2)

For voltage, we have to divide the axial conductance by the surface area of the sink, i.e. by A = 2 * pi * rad1 * l1

By that, we get: g_axial = rad1 * rad2^2 / (l1 * r_a1 * rad2^2 + l2 * r_a2 * rad1^2) / l1

Here, we define g_a = 1/r_a, because g_a can be zero (but not infinity as this would be inherently unstable).

Units: radius in um, g_a in Siemens / cm, r_a in ohm cm (unused, just for reference), length_single_compartment in um.
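Dividing the conductance by the sink's surface area A = 2 * pi * rad1 * l1, as described above, can be sketched as follows. The helper name is hypothetical, the g_a form (rather than r_a) is chosen so that g_a = 0 does not cause a division by zero, and any unit-conversion constants the library may apply are omitted.

```python
import math

def g_long_by_surface_area_sketch(rad1, rad2, g_a1, g_a2, l1, l2):
    """Hypothetical helper: rad1 * rad2^2 / (l1*r_a1*rad2^2 + l2*r_a2*rad1^2) / l1, with r_a = 1 / g_a."""
    denominator = l1 * g_a2 * rad2**2 + l2 * g_a1 * rad1**2
    if denominator == 0:
        return 0.0
    return rad1 * rad2**2 * g_a1 * g_a2 / denominator / l1

# Sanity check: the same value as g_long divided by the surface area of compartment 1.
rad1, rad2, g_a1, g_a2, l1, l2 = 1.0, 2.0, 0.5, 0.25, 3.0, 4.0
g_long = 2 * math.pi * rad1**2 * rad2**2 * g_a1 * g_a2 / (l1 * g_a2 * rad2**2 + l2 * g_a1 * rad1**2)
print(g_long_by_surface_area_sketch(rad1, rad2, g_a1, g_a2, l1, l2), g_long / (2 * math.pi * rad1 * l1))
```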

g_long_by_volume(rad1, rad2, g_a1, g_a2, l1, l2)

Return the ion diffusive constant between two compartments.

The axial conductance is: g_long = 2 * pi * rad1^2 * rad2^2 / (l1 * r_a1 * rad2^2 + l2 * r_a2 * rad1^2)

For ions, we have to divide the axial conductance by the volume of the sink, i.e. by V = pi * rad1^2 * l1

This gives: g_axial = 2 * rad2^2 / (l1 * r_a1 * rad2^2 + l2 * r_a2 * rad1^2) / l1

Expressed in conductances g_a (not r_a), this gives: g_axial = 2 * rad2^2 * g_a1 * g_a2 / (l1 * g_a2 * rad2^2 + l2 * g_a1 * rad1^2) / l1

Here, we define g_a = 1/r_a, because g_a can be zero (but not infinity as this would be inherently unstable). In particular, one might want g_a = 0 for ion diffusion.

Units: radius in um, g_a in mM / liter / cm, l in um.
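A minimal sketch of the g_a form given above; the helper name is hypothetical and unit-conversion constants are omitted. It shows that g_a = 0 switches off diffusion without any division by zero.

```python
import math

def g_long_by_volume_sketch(rad1, rad2, g_a1, g_a2, l1, l2):
    """Hypothetical helper: 2 * rad2^2 * g_a1 * g_a2 / (l1*g_a2*rad2^2 + l2*g_a1*rad1^2) / l1."""
    denominator = l1 * g_a2 * rad2**2 + l2 * g_a1 * rad1**2
    if denominator == 0:
        # Both g_a are zero: no diffusion between the compartments.
        return 0.0
    return 2 * rad2**2 * g_a1 * g_a2 / denominator / l1

# g_a1 = 0 switches off diffusion.
print(g_long_by_volume_sketch(1.0, 1.0, 0.0, 1.0, 1.0, 1.0))  # 0.0
```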

compute_impact_on_node(rad, g_a, l)

Compute the weight with which a compartment influences its node.

To satisfy Kirchhoff's current law, the current at a branch point must be proportional to the cross-section of the compartment. We only require proportionality here because the branch point equation reads: g_1 * (V_1 - V_b) + g_2 * (V_2 - V_b) = 0.0, and multiplying all g_i by the same constant does not change its solution for V_b.

Because R_long = r_a * (L/2) / cross-section, we get g_long = 2 * cross-section / (L * r_a), which is proportional to rad**2 / L / r_a.

Finally, we define g_a = 1 / r_a (in order to allow r_a=inf, or g_a=0).

The resulting weight can therefore be multiplied by any constant.
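The branch point equation can be solved for V_b as a weighted mean of the neighbouring voltages, which makes the "any constant" argument concrete. In this sketch `impact_on_node_sketch` and `branch_point_voltage` are hypothetical helpers; the weight rad**2 * g_a / l is the proportionality derived above, not necessarily the library's exact scaling.

```python
import numpy as np

def impact_on_node_sketch(rad, g_a, l):
    # Weight proportional to rad**2 / l / r_a = rad**2 * g_a / l (up to an arbitrary constant).
    return rad**2 * g_a / l

def branch_point_voltage(weights, voltages):
    # Solve sum_i g_i * (V_i - V_b) = 0 for V_b: a weighted mean of neighbouring voltages.
    return np.sum(weights * voltages) / np.sum(weights)

rad = np.array([1.0, 0.5, 0.5])
g_a = np.array([1.0, 1.0, 1.0])
l = np.array([10.0, 10.0, 10.0])
v = np.array([-70.0, -60.0, -60.0])
w = impact_on_node_sketch(rad, g_a, l)
print(branch_point_voltage(w, v))
print(branch_point_voltage(10 * w, v))  # same value: the constant cancels
```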

remap_to_consecutive(arr)

Maps an array of integers to an array of consecutive integers.

E.g. [0, 0, 1, 4, 4, 6, 6] -> [0, 0, 1, 2, 2, 3, 3]
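A minimal sketch using `np.unique` with `return_inverse=True`, which returns, for every element, its position among the sorted distinct values. Whether the library uses this approach is an assumption; `remap_to_consecutive_sketch` is a hypothetical name.

```python
import numpy as np

def remap_to_consecutive_sketch(arr):
    # np.unique returns the sorted distinct values and, with return_inverse=True,
    # the position of each input element in that sorted list.
    _, inverse = np.unique(arr, return_inverse=True)
    return inverse

print(remap_to_consecutive_sketch([0, 0, 1, 4, 4, 6, 6]).tolist())  # [0, 0, 1, 2, 2, 3, 3]
```

Because `np.unique` sorts, an unsorted input such as [4, 0] maps to [1, 0] (ranks) rather than to first-appearance order [0, 1]; for sorted inputs like the example above both conventions agree.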

interpolate_xyzr(loc, coords)

Perform a linear interpolation between xyz-coordinates.

Parameters:
  • loc (float) – The location in [0,1] along the branch.

  • coords (ndarray) – Array containing the reconstructed xyzr points of the branch.

Returns:

Interpolated xyz coordinate at loc, shape `(3,)`.
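A minimal sketch, assuming loc is measured as a fraction of the path length along the traced points and that the fourth column of coords holds the radius; `interpolate_xyz_sketch` is hypothetical and may differ from how the library parameterises loc.

```python
import numpy as np

def interpolate_xyz_sketch(loc, coords):
    # Hypothetical: loc is interpreted as a fraction of the path length along the traced points.
    xyz = np.asarray(coords, dtype=float)[:, :3]  # drop the radius column
    seg_lengths = np.linalg.norm(np.diff(xyz, axis=0), axis=1)
    cum_length = np.concatenate([[0.0], np.cumsum(seg_lengths)])
    if cum_length[-1] == 0.0:
        # All points coincide (or there is only one): nothing to interpolate.
        return xyz[0].copy()
    target = loc * cum_length[-1]
    # Interpolate each coordinate separately against the cumulative path length.
    return np.array([np.interp(target, cum_length, xyz[:, d]) for d in range(3)])

coords = np.array([[0.0, 0.0, 0.0, 1.0], [10.0, 0.0, 0.0, 1.0], [10.0, 10.0, 0.0, 1.0]])
print(interpolate_xyz_sketch(0.75, coords))
```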

params_to_pstate(params, indices_set_by_trainables)

Make the outputs of get_parameters() conform with the outputs of .data_set().

make_trainable() followed by params=get_parameters() does not return indices, because these indices would also be differentiated by jax.grad (as soon as the params are passed to def simulate(params)). Therefore, in jx.integrate, we run this function to add the indices to the dict. The outputs of params_to_pstate have the same shape as the outputs of .data_set().

Parameters:
  • params

  • indices_set_by_trainables

convert_point_process_to_distributed(current, radius, length)

Convert current point process (nA) to distributed current (uA/cm2).

This function gets called for synapses and for external stimuli.

Parameters:
  • current (Array) – Current in nA.

  • radius (Array) – Compartment radius in um.

  • length (Array) – Compartment length in um.

Returns:

Current in uA/cm2.

Return type:

Array
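A minimal sketch of the unit conversion, assuming the current is spread over the lateral surface area 2 * pi * radius * length of a cylindrical compartment; `point_to_distributed_sketch` is a hypothetical name. With 1 nA = 1e-3 uA and 1 um^2 = 1e-8 cm^2, the overall factor is 1e5 / (2 * pi * radius * length).

```python
import numpy as np

def point_to_distributed_sketch(current, radius, length):
    # Hypothetical: divide by the lateral surface area 2 * pi * radius * length of a cylinder.
    area_um2 = 2 * np.pi * np.asarray(radius) * np.asarray(length)
    area_cm2 = area_um2 * 1e-8  # 1 um^2 = 1e-8 cm^2
    current_uA = np.asarray(current) * 1e-3  # 1 nA = 1e-3 uA
    return current_uA / area_cm2

print(point_to_distributed_sketch(1.0, 1.0, 1.0))
```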

compute_morphology_indices_in_levels(num_branchpoints, child_belongs_to_branchpoint, par_inds, child_inds)

Return (row, col) to build the sparse matrix defining the voltage equations.

This is run at init, not during runtime.

group_and_sum(values_to_sum, inds_to_group_by, num_branchpoints)

Group values that share the same integer in inds_to_group_by and sum the values within each group.

This is used to construct the last diagonals at the branch points.

Written by ChatGPT.

Parameters:
  • values_to_sum (Array)

  • inds_to_group_by (Array)

  • num_branchpoints (int)

Return type:

Array
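A minimal NumPy sketch, assuming inds_to_group_by holds non-negative integers smaller than num_branchpoints; `group_and_sum_sketch` is a hypothetical name. In JAX, `jax.ops.segment_sum(values, inds, num_segments=num_branchpoints)` performs the same grouped sum.

```python
import numpy as np

def group_and_sum_sketch(values_to_sum, inds_to_group_by, num_branchpoints):
    # np.bincount adds each weight into the bin given by its integer index;
    # minlength guarantees one output entry per branch point, even for empty groups.
    return np.bincount(inds_to_group_by, weights=values_to_sum, minlength=num_branchpoints)

print(group_and_sum_sketch(np.array([1.0, 2.0, 3.0, 4.0]), np.array([0, 1, 0, 2]), 4))
```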

query_channel_states_and_params(d, keys, idcs)

Get dict with subset of keys and values from d.

This is used to restrict a dict where every item contains all states to only the ones that are relevant for the channel. E.g.

`states = {'eCa': Array([0., 0., nan])}`

will become `states = {'eCa': Array([0., 0.])}`

Only loops over necessary keys, as opposed to looping over d.items().
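A minimal sketch of the lookup, using a dict comprehension over the requested keys only; `query_sketch` is a hypothetical name and NumPy arrays stand in for JAX arrays.

```python
import numpy as np

def query_sketch(d, keys, idcs):
    # Only touch the requested keys instead of iterating over all of d.items().
    return {key: d[key][idcs] for key in keys}

states = {"eCa": np.array([0.0, 0.0, np.nan]), "v": np.array([-70.0, -70.0, -70.0])}
print(query_sketch(states, ["eCa"], np.array([0, 1])))
```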

compute_axial_conductances(comp_edges, params, diffusion_states)

Given comp_edges, radius, length, r_a, cm, compute the axial conductances.

Note that the resulting axial conductances will already be divided by the capacitance cm.

Parameters:
  • comp_edges

  • params

  • diffusion_states

Return type:

Dict[str, Array]

compute_children_and_parents(branch_edges)

Build indices used during `._init_morph_custom_spsolve()`.

Parameters:

branch_edges (DataFrame)

Return type:

Tuple[Array, Array, Array, int]

plot_graph(xyzr, dims=(0, 1), color='k', ax=None, type='line', **kwargs)

Plot morphology.

Parameters:
  • xyzr (ndarray) – The coordinates of the morphology.

  • dims (Tuple[int]) – Which dimensions to plot. 0=x, 1=y, 2=z coordinate. Must be a tuple of two or three of them.

  • color (str) – The color for all branches.

  • ax (Axes | None) – The matplotlib axis to plot on.

  • type (str) – Either 'line' or 'scatter'.

  • kwargs – The plot kwargs for plt.plot or plt.scatter.

Return type:

Axes

extract_outline(points)

Get the outline of a 2D/3D shape.

Extracts the subset of points which form the convex hull, i.e. the outline of the input points.

Parameters:

points (ndarray) – An array of points / coordinates.

Returns:

An array of points which form the convex hull.

Return type:

ndarray
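The convex-hull idea can be illustrated with a small pure-NumPy sketch. It swaps in Andrew's monotone chain algorithm and only handles 2D points; the library's own routine (which also covers 3D) is not shown above and may rely on a different convex-hull implementation. `convex_hull_2d` is a hypothetical name.

```python
import numpy as np

def convex_hull_2d(points):
    """Andrew's monotone chain: return hull vertices in counterclockwise order."""
    pts = sorted(set(map(tuple, np.asarray(points, dtype=float))))
    if len(pts) <= 2:
        return np.array(pts)

    def cross(o, a, b):
        # z-component of (a - o) x (b - o); > 0 means a counterclockwise turn.
        return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0])

    lower, upper = [], []
    for p in pts:
        while len(lower) >= 2 and cross(lower[-2], lower[-1], p) <= 0:
            lower.pop()
        lower.append(p)
    for p in reversed(pts):
        while len(upper) >= 2 and cross(upper[-2], upper[-1], p) <= 0:
            upper.pop()
        upper.append(p)
    # The last point of each chain is the first point of the other, so drop it.
    return np.array(lower[:-1] + upper[:-1])

square = np.array([[0, 0], [1, 0], [1, 1], [0, 1], [0.5, 0.5]])
print(convex_hull_2d(square))
```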

compute_rotation_matrix(axis, angle)

Return the rotation matrix associated with counterclockwise rotation about the given axis by the given angle.

Can be used to rotate a coordinate vector by multiplying it with the rotation matrix.

Parameters:
  • axis (ndarray) – The axis of rotation.

  • angle (float) – The angle of rotation in radians.

Returns:

A 3x3 rotation matrix.

Return type:

ndarray
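One standard way to build such a matrix is Rodrigues' rotation formula. The sketch below is hypothetical (`rotation_matrix_sketch` is not the library's function) and normalises the axis first, so the axis does not need unit length.

```python
import numpy as np

def rotation_matrix_sketch(axis, angle):
    # Rodrigues' formula: R = I + sin(angle) K + (1 - cos(angle)) K @ K,
    # where K is the cross-product matrix of the normalised axis.
    axis = np.asarray(axis, dtype=float)
    x, y, z = axis / np.linalg.norm(axis)
    K = np.array([[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]])
    return np.eye(3) + np.sin(angle) * K + (1.0 - np.cos(angle)) * (K @ K)

# Rotating the x unit vector counterclockwise about z by 90 degrees gives the y unit vector.
R = rotation_matrix_sketch([0, 0, 1], np.pi / 2)
print(np.round(R @ np.array([1.0, 0.0, 0.0]), 6))
```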

create_cone_frustum_mesh(length, radius_bottom, radius_top, bottom_dome=False, top_dome=False, resolution=100)

Generates mesh points for a cone frustum, with optional domes at either end.

This is used to render the traced morphology in 3D (and to project it to 2D) as part of plot_morph. Sections between two traced coordinates with two different radii can be represented by a cone frustum. Additionally, the ends of the frustum can be capped with hemispheres to ensure that two neighbouring frustums are connected smoothly (like ball joints).

Parameters:
  • length (float) – The length of the frustum.

  • radius_bottom (float) – The radius of the bottom of the frustum.

  • radius_top (float) – The radius of the top of the frustum.

  • bottom_dome (bool) – If True, a dome is added to the bottom of the frustum. The dome is a hemisphere with radius radius_bottom.

  • top_dome (bool) – If True, a dome is added to the top of the frustum. The dome is a hemisphere with radius radius_top.

  • resolution (int) – Defines the resolution of the mesh. If too low (typically <10), it can result in errors. A lower value gives a simpler mesh, which can be useful for plotting.

Returns:

An array of mesh points.

Return type:

ndarray

create_cylinder_mesh(length, radius, resolution=100)

Generates mesh points for a cylinder.

This is used to render cylindrical compartments in 3D (and to project them to 2D) as part of plot_comps.

Parameters:
  • length (float) – The length of the cylinder.

  • radius (float) – The radius of the cylinder.

  • resolution (int) – Defines the resolution of the mesh. If too low (typically <10), it can result in errors. A lower value gives a simpler mesh, which can be useful for plotting.

Returns:

An array of mesh points.

Return type:

ndarray
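A minimal sketch of one possible mesh layout: a (3, resolution, resolution) grid of surface points for a cylinder centred at the origin with its axis along z. The layout and the name `cylinder_mesh_sketch` are assumptions; the library's mesh may be arranged differently.

```python
import numpy as np

def cylinder_mesh_sketch(length, radius, resolution=100):
    # Hypothetical layout: a (3, resolution, resolution) grid of x, y, z surface points,
    # with the cylinder centred at the origin and its axis along z.
    theta = np.linspace(0.0, 2.0 * np.pi, resolution)
    z = np.linspace(-length / 2.0, length / 2.0, resolution)
    theta_grid, z_grid = np.meshgrid(theta, z)
    x_grid = radius * np.cos(theta_grid)
    y_grid = radius * np.sin(theta_grid)
    return np.stack([x_grid, y_grid, z_grid])

mesh = cylinder_mesh_sketch(length=10.0, radius=2.0, resolution=20)
print(mesh.shape)  # (3, 20, 20)
```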

create_sphere_mesh(radius, resolution=100)

Generates mesh points for a sphere.

This is used to render spherical compartments in 3D (and to project them to 2D) as part of plot_comps.

Parameters:
  • radius (float) – The radius of the sphere.

  • resolution (int) – Defines the resolution of the mesh. If too low (typically <10), it can result in errors. A lower value gives a simpler mesh, which can be useful for plotting.

Returns:

An array of mesh points.

Return type:

ndarray

plot_mesh(mesh_points, orientation, center, dims, ax=None, **kwargs)

Plot the 2D projection of a volume mesh on a cardinal plane.

Plot the projection of a cylinder that is oriented in 3D space. The steps are:

  • Create the cylinder mesh.

  • Rotate the cylinder mesh to orient it lengthwise along a given orientation vector.

  • Move its center.

  • Project it onto the plane.

  • Compute the outline of the projected mesh.

  • Fill the area inside the outline.

Parameters:
  • mesh_points (ndarray) – Coordinates of the xyz mesh that defines the volume.

  • orientation (ndarray) – Orientation vector. The cylinder will be oriented along this vector.

  • center (ndarray) – The x,y,z coordinates of the center of the cylinder.

  • dims (Tuple[int]) – The dimensions to plot / to project the cylinder onto, i.e. [0,1] xy-plane or [0,1,2] for 3D.

  • ax (Axes | None) – The matplotlib axis to plot on.

Returns:

Plot of the cylinder projection.

Return type:

Axes

plot_comps(module_or_view, dims=(0, 1), color='k', ax=None, true_comp_length=True, resolution=100, **kwargs)

Plot compartmentalized neural morphology.

Plots the projection of the cylindrical compartments.

Parameters:
  • module_or_view (jx.Module | jx.View) – The module or view to plot.

  • dims (Tuple[int]) – The dimensions to plot / to project the cylinder onto, i.e. [0,1] xy-plane or [0,1,2] for 3D.

  • color (str) – The color for all compartments.

  • ax (Axes | None) – The matplotlib axis to plot on.

  • true_comp_length (bool) – If True, the length of the compartment is used, i.e. the length of the traced neurite. This means that for zig-zagging neurites the cylinders will be longer than the straight-line distance between the start and end points of the neurite, which can lead to overlapping and misaligned cylinders. Setting this to False uses the straight-line distance instead, which gives nicer plots.

  • resolution (int) – Defines the resolution of the mesh. If too low (typically <10), it can result in errors. A lower value gives a simpler mesh, which can be useful for plotting.

  • kwargs – The plot kwargs for plt.fill.

Returns:

Plot of the compartmentalized morphology.

Return type:

Axes

plot_morph(module_or_view, dims=(0, 1), color='k', ax=None, resolution=100, **kwargs)

Plot the detailed morphology.

Plots the morphology as it was traced. That means that at every traced point a disc of radius r is plotted. The outlines of the discs are then connected to form the morphology, so every traced segment can be represented by a cone frustum. To prevent breaks in the morphology, neighbouring segments are connected with a ball joint.

Parameters:
  • module_or_view (jx.Module | jx.View) – The module or view to plot.

  • dims (Tuple[int]) – The dimensions to plot / to project the cylinder onto, i.e. [0,1] xy-plane or [0,1,2] for 3D.

  • color (str) – The color for all branches.

  • ax (Axes | None) – The matplotlib axis to plot on.

  • kwargs – The plot kwargs for plt.fill.

  • resolution (int) – Defines the resolution of the mesh. If too low (typically <10), it can result in errors. A lower value gives a simpler mesh, which can be useful for plotting.

Returns:

Plot of the detailed morphology.

Return type:

Axes

nested_checkpoint_scan(f, init, xs, length=None, *, nested_lengths, scan_fn=<function scan>, checkpoint_fn=<function checkpoint>)

A version of lax.scan that supports recursive gradient checkpointing.

Code taken from: google/jax#2139

The interface of nested_checkpoint_scan exactly matches lax.scan, except for the required nested_lengths argument.

The key feature of nested_checkpoint_scan is that gradient calculations require O(max(nested_lengths)) memory, vs O(prod(nested_lengths)) for unnested scans, which it achieves by re-evaluating the forward pass len(nested_lengths) - 1 times.

nested_checkpoint_scan reduces to lax.scan when nested_lengths has a single element.

Parameters:
  • f (Callable[[Carry, Dict[str, Array]], Tuple[Carry, Output]]) – function to scan over.

  • init (Carry) – initial value.

  • xs (Dict[str, Array]) – scanned over values.

  • length (int | None) – leading length of all dimensions

  • nested_lengths (Sequence[int]) – required list of lengths to scan over for each level of checkpointing. The product of nested_lengths must match length (if provided) and the size of the leading axis for all arrays in xs.

  • scan_fn – function matching the API of lax.scan

  • checkpoint_fn (Callable[[Func], Func]) – function matching the API of jax.checkpoint.

gather_synapes(number_of_compartments, post_syn_comp_inds, current_each_synapse_voltage_term, current_each_synapse_constant_term)

Compute current at the post synapse.

All this does is sum the synaptic currents that come into a particular compartment. It returns two arrays (one per current term), each with as many elements as there are compartments.

Parameters:
  • number_of_compartments (Array)

  • post_syn_comp_inds (ndarray)

  • current_each_synapse_voltage_term (Array)

  • current_each_synapse_constant_term (Array)

Return type:

Tuple[Array, Array]
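A minimal NumPy sketch of the summation, treating number_of_compartments as an integer and returning one summed array per current term, which matches the Tuple[Array, Array] return type. The name `gather_synapses_sketch` is hypothetical. `np.add.at` is used because plain fancy-index assignment (`a[inds] += vals`) applies repeated indices only once; in JAX the equivalent is `jnp.zeros(n).at[inds].add(vals)`.

```python
import numpy as np

def gather_synapses_sketch(number_of_compartments, post_syn_comp_inds,
                           current_each_synapse_voltage_term,
                           current_each_synapse_constant_term):
    # Sum every synapse's contribution into the compartment it targets.
    voltage_term = np.zeros(number_of_compartments)
    constant_term = np.zeros(number_of_compartments)
    # np.add.at accumulates correctly even when several synapses share a compartment.
    np.add.at(voltage_term, post_syn_comp_inds, current_each_synapse_voltage_term)
    np.add.at(constant_term, post_syn_comp_inds, current_each_synapse_constant_term)
    return voltage_term, constant_term

v_term, c_term = gather_synapses_sketch(
    4, np.array([1, 1, 3]), np.array([0.1, 0.2, 0.3]), np.array([1.0, 2.0, 3.0])
)
print(v_term, c_term)
```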