Utils
- build_radiuses_from_xyzr(radius_fns, branch_indices, min_radius, ncomp)
Return the radiuses of branches given SWC file xyzr.
Returns an array of shape (num_branches, ncomp).
- Parameters:
radius_fns (List[Callable]) – Functions which, given compartment locations, return the radius.
branch_indices (List[int]) – The indices of the branches for which to return the radiuses.
min_radius (float | None) – If passed, the radiuses are clipped to be at least this large.
ncomp (int) – The number of compartments that every branch is discretized into.
- Return type:
- equal_segments(branch_property, ncomp_per_branch)
Generates segments where some property is the same in each segment.
- linear_segments(initial_val, endpoint_vals, parents, ncomp_per_branch)
Generates segments where some property is linearly interpolated.
- merge_cells(cumsum_num_branches, cumsum_num_branchpoints, arrs, exclude_first=True)
Build the full list of which branches are solved in which iteration.
From the branching pattern of single cells, this “merges” them into a single ordering of branches.
- Parameters:
cumsum_num_branches (List[int]) – Cumulative number of branches. E.g., for three cells with 10, 15, and 5 branches respectively, this should be a list containing [0, 10, 25, 30].
arrs (List[List[ndarray]]) – A list of lists of arrays that should be merged.
exclude_first (bool) – If True, the first element of each list in arrs will remain unchanged. Useful if a -1 (which indicates “no parent”) entry should not be changed.
- Returns:
A list of arrays which contain the branch indices that are computed at each level (i.e., iteration).
- Return type:
- compute_children_indices(parents)
Return all children indices of every branch.
Example:
For `parents = [-1, 0, 0]`, `compute_children_indices(parents)` returns `[[1, 2], [], []]`.
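Here is a minimal sketch of how the children lists can be derived from `parents`. Like the example, it assumes that `-1` marks a branch without a parent. The name `compute_children_indices_sketch` is hypothetical and not part of the library.

```python
from typing import List


def compute_children_indices_sketch(parents: List[int]) -> List[List[int]]:
    """Hypothetical re-implementation: list the children of every branch."""
    children: List[List[int]] = [[] for _ in parents]
    for child, parent in enumerate(parents):
        if parent != -1:  # -1 marks a root branch, which has no parent
            children[parent].append(child)
    return children


print(compute_children_indices_sketch([-1, 0, 0]))  # [[1, 2], [], []]
```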
- get_num_neighbours(num_children, ncomp_per_branch, num_branches)
Number of neighbours of each compartment.
- local_index_of_loc(loc, global_branch_ind, ncomp_per_branch)
Returns the local index of a compartment given a location loc in [0, 1] and the index of a branch.
This is needed because locations, such as those of synapses, are specified as a value between 0 and 1, which has to be converted to a discrete compartment here.
- loc_of_index(global_comp_index, global_branch_index, ncomp_per_branch)
Return location corresponding to global compartment index.
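The sketch below shows how a continuous location and a discrete compartment can be converted into each other. It assumes that compartment `i` of a branch with `ncomp` compartments is centered at `(i + 0.5) / ncomp` and that `loc` maps to the nearest center. Both function names are hypothetical, and the library's own rounding convention may differ.

```python
import numpy as np


def local_index_of_loc_sketch(loc, global_branch_ind, ncomp_per_branch):
    # Assumption: compartment i sits at the normalized center (i + 0.5) / ncomp.
    ncomp = ncomp_per_branch[global_branch_ind]
    centers = (np.arange(ncomp) + 0.5) / ncomp
    # Pick the compartment whose center is closest to `loc`.
    return int(np.argmin(np.abs(centers - loc)))


def loc_of_index_sketch(global_comp_index, global_branch_index, ncomp_per_branch):
    # Global index of the first compartment of every branch.
    offsets = np.concatenate([[0], np.cumsum(ncomp_per_branch)])
    local_index = global_comp_index - offsets[global_branch_index]
    ncomp = ncomp_per_branch[global_branch_index]
    return (local_index + 0.5) / ncomp


ncomp_per_branch = [4, 2]  # two branches with 4 and 2 compartments
print(local_index_of_loc_sketch(0.3, 0, ncomp_per_branch))  # 1
print(loc_of_index_sketch(5, 1, ncomp_per_branch))  # 0.75
```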
- compute_coupling_cond(rad1, rad2, r_a1, r_a2, l1, l2)
Return the coupling conductance between two compartments.
Equations taken from https://en.wikipedia.org/wiki/Compartmental_neuron_models.
Units: radius in um, r_a in ohm cm, length_single_compartment in um. The resulting coupling_conds have units of S * um / cm / um^2 = S / cm / um, which are multiplied by 10**7 to obtain mS / cm^2.
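The following sketch writes out a coupling conductance under standard compartmental-model assumptions. The two half-compartments between the compartment centers act as resistors in series, and the resulting conductance is normalized by the membrane area 2 pi rad1 l1 of the first compartment. The function name is hypothetical. The library's exact expression, for instance how it treats differing r_a values, may differ.

```python
def coupling_cond_sketch(rad1, rad2, r_a1, r_a2, l1, l2):
    """Coupling conductance of compartment 1 towards compartment 2 (mS / cm^2).

    Units: rad in um, r_a in ohm cm, l in um.
    """
    # Series resistance of two half-cylinders, each R_i = r_a_i * (l_i / 2) / (pi rad_i^2).
    # Its inverse divided by the membrane area 2 pi rad1 l1 of compartment 1
    # simplifies (all factors of pi cancel) to:
    g = rad1 * rad2**2 / (r_a1 * l1 * rad2**2 + r_a2 * l2 * rad1**2) / l1
    # S / cm / um -> mS / cm^2
    return g * 10**7
```

Because the normalization is by the area of compartment 1, the conductances in the two directions differ in general. Only their product with the respective membrane area, the total axial conductance, is symmetric.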
- compute_coupling_cond_branchpoint(rad, r_a, l)
Return the coupling conductance between one compartment and a compartment with l=0.
From https://en.wikipedia.org/wiki/Compartmental_neuron_models
If one compartment has l=0.0, then the equations simplify. The longitudinal resistance is
$$R_\text{long} = \sum_i \frac{r_a L_i / 2}{A_i}, \qquad A_i = \pi r_i^2,$$
where A_i is the cross-section of compartment i. For a single compartment with L>0, this turns into
$$R_\text{long} = \frac{r_a L / 2}{\pi r^2}, \qquad g_\text{long} = \frac{1}{R_\text{long}} = \frac{2 \pi r^2}{L \, r_a}.$$
The effective conductance is g_long divided by the cylinder area 2 pi r L:
$$g = \frac{2 \pi r^2}{L \, r_a} \cdot \frac{1}{2 \pi r L} = \frac{r}{r_a L^2}.$$
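As a numerical check of the derivation above, this sketch computes the conductance step by step: R_long, then g_long, then division by the cylinder area. It then compares the result to the closed form r / (r_a L^2). Unit conversion is left out, and both function names are hypothetical.

```python
import math


def branchpoint_cond_from_resistance(rad, r_a, l):
    crosssection = math.pi * rad**2
    # Only the compartment with L > 0 contributes, with half its length.
    r_long = r_a * (l / 2) / crosssection
    g_long = 1.0 / r_long
    # Normalize by the lateral (cylinder) membrane area 2 pi r L.
    cylinder_area = 2 * math.pi * rad * l
    return g_long / cylinder_area


def branchpoint_cond_closed_form(rad, r_a, l):
    return rad / r_a / l**2
```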
- compute_impact_on_node(rad, r_a, l)
Compute the weight with which a compartment influences its node.
In order to satisfy Kirchhoff's current law, the current at a branch point must be proportional to the cross-section of the compartment. We only require proportionality here because the branch point equation reads
$$g_1 (V_1 - V_b) + g_2 (V_2 - V_b) = 0.$$
Because $R_\text{long} = r_a (L/2) / A$ with cross-section $A = \pi r^2$, we get $g_\text{long} = 2A / (L \, r_a) \propto r^2 / (L \, r_a)$.
This equation can be multiplied by any constant.
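The role of these weights can be illustrated by solving the branch point equation for V_b, which gives a weighted average of the neighbouring voltages. In this sketch the weights are taken as rad**2 / (l * r_a), and the function names are hypothetical. Because V_b is a ratio of weighted sums, multiplying all weights by the same constant leaves it unchanged.

```python
import numpy as np


def impact_on_node_sketch(rad, r_a, l):
    # Proportional to the axial conductance; the constant 2 * pi is dropped.
    return rad**2 / l / r_a


def branchpoint_voltage(voltages, weights):
    # Solve sum_i g_i * (V_i - V_b) = 0 for V_b.
    voltages = np.asarray(voltages, dtype=float)
    weights = np.asarray(weights, dtype=float)
    return np.sum(weights * voltages) / np.sum(weights)
```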
- remap_to_consecutive(arr)
Maps an array of integers to an array of consecutive integers.
E.g., [0, 0, 1, 4, 4, 6, 6] -> [0, 0, 1, 2, 2, 3, 3]
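One way to obtain such a mapping is `np.unique` with `return_inverse=True`. The sketch below assumes that labels are assigned by sorted value. For unsorted input, an implementation that numbers values by order of first appearance would give different labels. The function name is hypothetical.

```python
import numpy as np


def remap_to_consecutive_sketch(arr):
    # np.unique returns the sorted distinct values; return_inverse gives, for
    # every element of `arr`, its position in that sorted list.
    _, inverse = np.unique(np.asarray(arr), return_inverse=True)
    return inverse


print(remap_to_consecutive_sketch([0, 0, 1, 4, 4, 6, 6]).tolist())  # [0, 0, 1, 2, 2, 3, 3]
```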
- params_to_pstate(params, indices_set_by_trainables)
Make the outputs of get_parameters() conform with the outputs of .data_set().
make_trainable() followed by params=get_parameters() does not return indices, because these indices would also be differentiated by jax.grad (as soon as the params are passed to def simulate(params)). Therefore, in jx.integrate, we run this function to add the indices to the dict. The outputs of params_to_pstate have the same shape as the outputs of .data_set().
- convert_point_process_to_distributed(current, radius, length)
Convert current point process (nA) to distributed current (uA/cm2).
This function gets called for synapses and for external stimuli.
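This sketch of the unit conversion assumes that radius and length are given in um and that the point current is spread uniformly over the lateral surface area 2 pi r L of a cylindrical compartment. The conversion factor follows from 1 nA / um^2 = 10^-3 uA / 10^-8 cm^2 = 10^5 uA / cm^2. The function name is hypothetical.

```python
import math


def convert_point_process_to_distributed_sketch(current, radius, length):
    """Point current (nA) -> current density (uA / cm^2), assuming um for radius and length."""
    area_um2 = 2 * math.pi * radius * length  # lateral cylinder area in um^2
    # nA / um^2 = 1e-3 uA / 1e-8 cm^2 = 1e5 uA / cm^2
    return current / area_um2 * 1e5
```

For example, 0.1 nA injected into a compartment with radius 1 um and length 10 um corresponds to roughly 159 uA/cm^2.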
- compute_morphology_indices_in_levels(num_branchpoints, child_belongs_to_branchpoint, par_inds, child_inds)
Return (row, col) to build the sparse matrix defining the voltage equations.
This is run at init, not during runtime.
- group_and_sum(values_to_sum, inds_to_group_by, num_branchpoints)
Group values by the integer they are assigned to in inds_to_group_by and sum the values within each group.
This is used to construct the last diagonals at the branch points.
Written by ChatGPT.
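The grouping and summing can be expressed with `np.bincount`, as in this sketch. It assumes that `inds_to_group_by` holds integers in `[0, num_branchpoints)`. The function name is hypothetical.

```python
import numpy as np


def group_and_sum_sketch(values_to_sum, inds_to_group_by, num_branchpoints):
    # bincount adds up the `weights` of all entries that share the same integer.
    # minlength guarantees one output entry per branch point, even if a branch
    # point receives no values.
    return np.bincount(
        np.asarray(inds_to_group_by),
        weights=np.asarray(values_to_sum, dtype=float),
        minlength=num_branchpoints,
    )


print(group_and_sum_sketch([1.0, 2.0, 3.0, 4.0], [0, 1, 0, 2], 4).tolist())  # [4.0, 2.0, 4.0, 0.0]
```

Inside JAX-transformed code, `jax.ops.segment_sum(data, segment_ids, num_segments=...)` performs the same reduction.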
- query_channel_states_and_params(d, keys, idcs)
Get dict with subset of keys and values from d.
This is used to restrict a dict in which every item contains all states to only the ones that are relevant for the channel. E.g.,
`states = {'eCa': Array([0., 0., nan])}` becomes
`states = {'eCa': Array([0., 0.])}`.
Only loops over the necessary keys, as opposed to looping over d.items().
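A sketch of the restriction as a dict comprehension that touches only the requested keys. It uses NumPy arrays in place of JAX arrays, and the function name is hypothetical.

```python
import numpy as np


def query_channel_states_and_params_sketch(d, keys, idcs):
    # Iterate over the requested keys only, not over all of d.items().
    return {key: d[key][idcs] for key in keys}


states = {"eCa": np.array([0.0, 0.0, np.nan]), "v": np.array([-70.0, -65.0, -60.0])}
channel_states = query_channel_states_and_params_sketch(states, ["eCa"], np.array([0, 1]))
print(channel_states)  # {'eCa': array([0., 0.])}
```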
- compute_axial_conductances(comp_edges, params)
Given comp_edges, radius, length, r_a, cm, compute the axial conductances.
Note that the resulting axial conductances will already be divided by the capacitance cm.
- compute_children_and_parents(branch_edges)
Build indices used during `._init_morph_custom_spsolve()`.
- plot_graph(xyzr, dims=(0, 1), col='k', ax=None, type='line', morph_plot_kwargs={})
Plot morphology.
- Parameters:
xyzr (ndarray) – The coordinates of the morphology.
dims (Tuple[int]) – Which dimensions to plot. 0=x, 1=y, 2=z coordinate. Must be a tuple of two or three of them.
col (str) – The color for all branches.
ax (Axes | None) – The matplotlib axis to plot on.
type (str) – Either 'line' or 'scatter'.
morph_plot_kwargs (Dict) – The plot kwargs for plt.plot or plt.scatter.
- Return type:
Axes
- extract_outline(points)
Get the outline of a 2D/3D shape.
Extracts the subset of points which form the convex hull, i.e. the outline of the input points.
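In the 2D case, the outline can be computed with Andrew's monotone chain convex-hull algorithm, sketched below in plain NumPy. This technique is swapped in for illustration, and the library may compute the hull differently. For 3D points, `scipy.spatial.ConvexHull` is a common choice. The function name is hypothetical, and collinear boundary points are dropped.

```python
import numpy as np


def _cross(o, a, b):
    # z-component of (a - o) x (b - o); positive for a counterclockwise turn o -> a -> b.
    return (a[0] - o[0]) * (b[1] - o[1]) - (a[1] - o[1]) * (b[0] - o[0])


def extract_outline_2d(points):
    """Return the convex-hull vertices of 2D points in counterclockwise order."""
    # Deduplicate and sort lexicographically by (x, y).
    pts = sorted(set(map(tuple, np.asarray(points, dtype=float))))
    if len(pts) <= 2:
        return np.array(pts)
    lower, upper = [], []
    for p in pts:  # lower hull, left to right
        while len(lower) >= 2 and _cross(lower[-2], lower[-1], p) <= 0:
            lower.pop()
        lower.append(p)
    for p in reversed(pts):  # upper hull, right to left
        while len(upper) >= 2 and _cross(upper[-2], upper[-1], p) <= 0:
            upper.pop()
        upper.append(p)
    # The last point of each chain is the first point of the other one.
    return np.array(lower[:-1] + upper[:-1])
```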
- compute_rotation_matrix(axis, angle)
Return the rotation matrix associated with counterclockwise rotation about the given axis by the given angle.
Can be used to rotate a coordinate vector by multiplying it with the rotation matrix.
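Here is a sketch based on Rodrigues' rotation formula, R = I + sin(angle) K + (1 - cos(angle)) K^2. In this formula, K is the cross-product matrix of the normalized axis, and the angle is in radians. The function name is hypothetical.

```python
import numpy as np


def rotation_matrix_sketch(axis, angle):
    """Counterclockwise rotation by `angle` (radians) about `axis`."""
    axis = np.asarray(axis, dtype=float)
    x, y, z = axis / np.linalg.norm(axis)
    # Cross-product matrix: K @ v == np.cross(unit_axis, v).
    K = np.array([[0.0, -z, y], [z, 0.0, -x], [-y, x, 0.0]])
    return np.eye(3) + np.sin(angle) * K + (1 - np.cos(angle)) * (K @ K)


# Rotating the x unit vector by 90 degrees about z yields (approximately) the y unit vector.
R = rotation_matrix_sketch([0, 0, 1], np.pi / 2)
rotated = R @ np.array([1.0, 0.0, 0.0])
```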
- create_cone_frustum_mesh(length, radius_bottom, radius_top, bottom_dome=False, top_dome=False, resolution=100)
Generates mesh points for a cone frustum, with optional domes at either end.
This is used to render the traced morphology in 3D (and to project it to 2D) as part of plot_morph. Sections between two traced coordinates with two different radii can be represented by a cone frustum. Additionally, the ends of the frustum can be capped with hemispheres to ensure that two neighbouring frustums are connected smoothly (like ball joints).
- Parameters:
length (float) – The length of the frustum.
radius_bottom (float) – The radius of the bottom of the frustum.
radius_top (float) – The radius of the top of the frustum.
bottom_dome (bool) – If True, a dome is added to the bottom of the frustum. The dome is a hemisphere with radius radius_bottom.
top_dome (bool) – If True, a dome is added to the top of the frustum. The dome is a hemisphere with radius radius_top.
resolution (int) – Defines the resolution of the mesh. If it is too low (typically <10), errors can result. Lower values are useful to obtain a simpler mesh for plotting.
- Returns:
An array of mesh points.
- Return type:
- create_cylinder_mesh(length, radius, resolution=100)
Generates mesh points for a cylinder.
This is used to render cylindrical compartments in 3D (and to project them to 2D) as part of plot_comps.
- Parameters:
- Returns:
An array of mesh points.
- Return type:
- create_sphere_mesh(radius, resolution=100)
Generates mesh points for a sphere.
This is used to render spherical compartments in 3D (and to project them to 2D) as part of plot_comps.
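Here is a sketch of a sphere mesh built from spherical coordinates. The output layout is an assumption: an array of shape (3, resolution, resolution) holding the x, y and z coordinates. The function name is hypothetical.

```python
import numpy as np


def sphere_mesh_sketch(radius, resolution=100):
    """Mesh points on a sphere of the given radius, centered at the origin."""
    # Regular grid over the polar angle theta in [0, pi] and the azimuth phi in [0, 2 pi].
    theta, phi = np.meshgrid(
        np.linspace(0.0, np.pi, resolution), np.linspace(0.0, 2 * np.pi, resolution)
    )
    x = radius * np.sin(theta) * np.cos(phi)
    y = radius * np.sin(theta) * np.sin(phi)
    z = radius * np.cos(theta)
    return np.stack([x, y, z])  # shape (3, resolution, resolution)
```

A cone frustum mesh in the same layout can be built the same way. Replace the polar angle by a height that runs from 0 to length, and interpolate the radius linearly between the two end radii. When both radii are equal, the result is a cylinder.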
- plot_mesh(mesh_points, orientation, center, dims, ax=None, **kwargs)
Plot the 2D projection of a volume mesh on a cardinal plane.
Plot the projection of a cylinder that is oriented in 3D space. The steps are: (1) create the cylinder mesh, (2) rotate the cylinder mesh to orient it lengthwise along a given orientation vector, (3) move its center, (4) project it onto the plane, (5) compute the outline of the projected mesh, and (6) fill the area inside the outline.
- Parameters:
mesh_points (ndarray) – coordinates of the xyz mesh that define the volume
orientation (ndarray) – orientation vector. The cylinder will be oriented along this vector.
center (ndarray) – The x,y,z coordinates of the center of the cylinder.
dims (Tuple[int]) – The dimensions to plot / to project the cylinder onto, i.e. [0,1] xy-plane or [0,1,2] for 3D.
ax (Axes | None) – The matplotlib axis to plot on.
- Returns:
Plot of the cylinder projection.
- Return type:
Axes
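The rotate, move and project steps of plot_mesh can be sketched in NumPy as follows. The sketch assumes that the input mesh is centered at the origin, oriented along the z-axis and stored as an array of shape (3, ...). It builds the rotation that maps the z-axis onto the orientation vector with Rodrigues' formula. Both function names are hypothetical. The outline and fill steps are omitted; they could use a convex hull of the projected points and matplotlib's `Axes.fill`.

```python
import numpy as np


def _rotation_aligning_z_to(orientation):
    """Rotation matrix mapping the z-axis onto `orientation` (Rodrigues' formula)."""
    target = np.asarray(orientation, dtype=float)
    target = target / np.linalg.norm(target)
    z_axis = np.array([0.0, 0.0, 1.0])
    axis = np.cross(z_axis, target)
    sin_angle = np.linalg.norm(axis)
    cos_angle = np.dot(z_axis, target)
    if np.isclose(sin_angle, 0.0):
        # Parallel: identity. Anti-parallel: rotate by pi about the y-axis.
        return np.eye(3) if cos_angle > 0 else np.diag([-1.0, 1.0, -1.0])
    kx, ky, kz = axis / sin_angle
    # Cross-product matrix of the unit rotation axis.
    K = np.array([[0.0, -kz, ky], [kz, 0.0, -kx], [-ky, kx, 0.0]])
    return np.eye(3) + sin_angle * K + (1.0 - cos_angle) * (K @ K)


def project_mesh(mesh_points, orientation, center, dims):
    """Rotate a z-aligned mesh along `orientation`, move it to `center`, keep `dims`."""
    points = np.asarray(mesh_points, dtype=float).reshape(3, -1)
    rotated = _rotation_aligning_z_to(orientation) @ points
    moved = rotated + np.asarray(center, dtype=float).reshape(3, 1)
    return moved[list(dims)]  # shape (len(dims), number of points)
```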
- plot_comps(module_or_view, dims=(0, 1), col='k', ax=None, comp_plot_kwargs={}, true_comp_length=True, resolution=100)
Plot compartmentalized neural morphology.
Plots the projection of the cylindrical compartments.
- Parameters:
module_or_view (jx.Module | jx.View) – The module or view to plot.
dims (Tuple[int]) – The dimensions to plot / to project the cylinder onto, i.e. [0,1] xy-plane or [0,1,2] for 3D.
col (str) – The color for all compartments.
ax (Axes | None) – The matplotlib axis to plot on.
comp_plot_kwargs (Dict) – The plot kwargs for plt.fill.
true_comp_length (bool) – If True, the length of the compartment is used, i.e. the length of the traced neurite. This means that for zig-zagging neurites the cylinders will be longer than the straight-line distance between the start and end point of the neurite. This can lead to overlapping and misaligned cylinders. Setting this to False uses the straight-line distance instead, which gives nicer plots.
resolution (int) – Defines the resolution of the mesh. If it is too low (typically <10), errors can result. Lower values are useful to obtain a simpler mesh for plotting.
- Returns:
Plot of the compartmentalized morphology.
- Return type:
Axes
- plot_morph(module_or_view, dims=(0, 1), col='k', ax=None, resolution=100, morph_plot_kwargs={})
Plot the detailed morphology.
Plots the morphology as it was traced. That means that at every traced point a disc of radius r is plotted. The outlines of the discs are then connected to form the morphology. This means every trace segment can be represented by a cone frustum. To prevent breaks in the morphology, each segment is connected with a ball joint.
- Parameters:
module_or_view (jx.Module | jx.View) – The module or view to plot.
dims (Tuple[int]) – The dimensions to plot / to project the cylinder onto, i.e. [0,1] xy-plane or [0,1,2] for 3D.
col (str) – The color for all branches.
ax (Axes | None) – The matplotlib axis to plot on.
morph_plot_kwargs (Dict) – The plot kwargs for plt.fill.
resolution (int) – Defines the resolution of the mesh. If it is too low (typically <10), errors can result. Lower values are useful to obtain a simpler mesh for plotting.
- Returns:
Plot of the detailed morphology.
- Return type:
Axes
- nested_checkpoint_scan(f, init, xs, length=None, *, nested_lengths, scan_fn=<function scan>, checkpoint_fn=<function checkpoint>)
A version of lax.scan that supports recursive gradient checkpointing.
Code taken from: google/jax#2139
The interface of nested_checkpoint_scan exactly matches lax.scan, except for the required nested_lengths argument.
The key feature of nested_checkpoint_scan is that gradient calculations require O(max(nested_lengths)) memory, vs O(prod(nested_lengths)) for unnested scans, which it achieves by re-evaluating the forward pass len(nested_lengths) - 1 times.
nested_checkpoint_scan reduces to lax.scan when nested_lengths has a single element.
- Parameters:
f (Callable[[Carry, Dict[str, Array]], Tuple[Carry, Output]]) – function to scan over.
init (Carry) – initial value.
length (int | None) – the length of the leading axis of all arrays in xs (optional, as in lax.scan).
nested_lengths (Sequence[int]) – required list of lengths to scan over for each level of checkpointing. The product of nested_lengths must match length (if provided) and the size of the leading axis for all arrays in xs.
scan_fn – function matching the API of lax.scan.
checkpoint_fn (Callable[[Func], Func]) – function matching the API of jax.checkpoint.
- gather_synapes(number_of_compartments, post_syn_comp_inds, current_each_synapse_voltage_term, current_each_synapse_constant_term)
Compute current at the post synapse.
All this does is sum the synaptic currents that come into a particular compartment. It returns an array with as many elements as there are compartments.
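Here is a sketch of this summation with `np.add.at`. It assumes one voltage term and one constant term per synapse and returns one array per term. The function name `gather_synapses_sketch` is hypothetical.

```python
import numpy as np


def gather_synapses_sketch(number_of_compartments, post_syn_comp_inds,
                           current_each_synapse_voltage_term,
                           current_each_synapse_constant_term):
    """Sum per-synapse current terms into their postsynaptic compartments."""
    inds = np.asarray(post_syn_comp_inds)
    voltage_term = np.zeros(number_of_compartments)
    constant_term = np.zeros(number_of_compartments)
    # Unlike `voltage_term[inds] += values`, np.add.at accumulates every entry
    # when several synapses target the same compartment.
    np.add.at(voltage_term, inds, current_each_synapse_voltage_term)
    np.add.at(constant_term, inds, current_each_synapse_constant_term)
    return voltage_term, constant_term
```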