# Source code for jaxley.utils.cell_utils

# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from math import pi
from typing import Callable, Dict, List, Optional, Tuple

import jax.numpy as jnp
import numpy as np
import pandas as pd
from jax import vmap

from jaxley.utils.misc_utils import cumsum_leading_zero


def build_radiuses_from_xyzr(
    radius_fns: List[Callable],
    branch_indices: List[int],
    min_radius: Optional[float],
    ncomp: int,
) -> np.ndarray:
    """Return the compartment radiuses of branches given SWC file xyzr.

    Every branch is discretized into `ncomp` compartments whose centers lie at the
    relative locations `(i + 0.5) / ncomp`; the branch's radius function is
    evaluated at those centers.

    Args:
        radius_fns: Functions which, given compartment locations, return the radius.
        branch_indices: The indices of the branches for which to return the radiuses.
        min_radius: If passed, the radiuses are clipped to be at least as large.
        ncomp: The number of compartments that every branch is discretized into.

    Returns:
        A flat array of shape `(len(branch_indices) * ncomp,)`. The `ncomp`
        radiuses of `branch_indices[0]` come first, followed by those of
        `branch_indices[1]`, and so on.

    Raises:
        AssertionError: If `min_radius is None` and any radius is not strictly
            positive.
    """
    # Compartment locations are at the center of the internal nodes.
    non_split = 1 / ncomp
    range_ = np.linspace(non_split / 2, 1 - non_split / 2, ncomp)

    # Build radiuses: one row per branch, then flatten row-major so that the
    # compartments of one branch stay contiguous.
    radiuses = np.asarray([radius_fns[b](range_) for b in branch_indices])
    radiuses_each = radiuses.ravel(order="C")
    if min_radius is None:
        assert np.all(
            radiuses_each > 0.0
        ), "Radius 0.0 in SWC file. Set `read_swc(..., min_radius=...)`."
    else:
        radiuses_each[radiuses_each < min_radius] = min_radius

    return radiuses_each
def equal_segments(branch_property: list, ncomp_per_branch: int):
    """Build per-compartment values where every compartment of a branch is equal.

    Args:
        branch_property: Value of the property in each branch. Should have
            `len(branch_property) == num_branches`.
        ncomp_per_branch: Number of compartments in every branch.

    Returns:
        An array of shape `(num_branches, ncomp_per_branch)` whose row `b` is
        filled with `branch_property[b]`.
    """
    assert isinstance(branch_property, list), "branch_property must be a list."
    # One copy of the full property list per compartment, then transpose so that
    # rows correspond to branches.
    per_compartment_rows = [branch_property for _ in range(ncomp_per_branch)]
    return jnp.transpose(jnp.asarray(per_compartment_rows))
def linear_segments(
    initial_val: float, endpoint_vals: list, parents: jnp.ndarray, ncomp_per_branch: int
):
    """Generate segments where some property is linearly interpolated.

    Within every branch, the value is interpolated linearly between the endpoint
    value of its parent (or `initial_val` for a branch whose parent is `-1`) and
    its own endpoint value.

    Args:
        initial_val: The value at the tip of the soma.
        endpoint_vals: The value at the endpoints of each branch (any sequence;
            lists, numpy and jax arrays are all accepted).
        parents: Parent branch index of every branch (`-1` for the root).
        ncomp_per_branch: Number of compartments in every branch.

    Returns:
        An array of shape `(num_branches, ncomp_per_branch)`. Compartment 0 of a
        branch holds the branch's own endpoint value and the last compartment
        holds its parent's endpoint value (locations run from 1 to 0).
    """
    # `list(...)` so that array-valued `endpoint_vals` are *concatenated* with
    # `initial_val`; `array + [x]` would broadcast-add instead.
    branch_property = list(endpoint_vals) + [initial_val]
    # Convert so that `parents` can be fancy-indexed with a jax index array even
    # when a plain Python list is passed.
    parents = jnp.asarray(parents)
    num_branches = len(parents)

    # Compute radiuses by linear interpolation. `initial_val` sits at index -1,
    # so a parent index of -1 picks it up.
    endpoint_radiuses = jnp.asarray(branch_property)

    def compute_rad(branch_ind, loc):
        start = endpoint_radiuses[parents[branch_ind]]
        end = endpoint_radiuses[branch_ind]
        return (end - start) * loc + start

    branch_inds_of_each_comp = jnp.tile(jnp.arange(num_branches), ncomp_per_branch)
    locs_of_each_comp = jnp.linspace(1, 0, ncomp_per_branch).repeat(num_branches)
    rad_of_each_comp = compute_rad(branch_inds_of_each_comp, locs_of_each_comp)

    return jnp.reshape(rad_of_each_comp, (ncomp_per_branch, num_branches)).T
def merge_cells(
    cumsum_num_branches: List[int],
    cumsum_num_branchpoints: List[int],
    arrs: List[List[np.ndarray]],
    exclude_first: bool = True,
) -> List[np.ndarray]:
    """Build full list of which branches are solved in which iteration.

    From the branching pattern of single cells, this "merges" them into a single
    ordering of branches.

    Args:
        cumsum_num_branches: Cumulative number of branches. E.g., for three cells
            with 10, 15, and 5 branches respectively, this should be a list
            containing `[0, 10, 25, 30]`.
        cumsum_num_branchpoints: Cumulative number of branchpoints, analogous to
            `cumsum_num_branches`.
        arrs: One entry per cell; each entry is a list with one array per level.
            The last axis of every array has length 2: the first column is
            offset by `cumsum_num_branches[i]`, the second by
            `cumsum_num_branchpoints[i]`.
        exclude_first: Only `False` is supported. Passing `True` (the default)
            raises `NotImplementedError`.

    Returns:
        A list with one array per level (i.e., iteration), containing the offset
        indices of all cells that have that level. Empty if `arrs` is empty.

    Raises:
        NotImplementedError: If `exclude_first` is `True`.
    """
    if exclude_first:
        # Keeping the leading `-1` ("no parent") entry un-offset is not
        # implemented; callers must pass `exclude_first=False`.
        raise NotImplementedError

    ps = []
    for i, levels_of_cell in enumerate(arrs):
        offset = np.asarray([cumsum_num_branches[i], cumsum_num_branchpoints[i]])
        ps.append([p_in_level + offset for p_in_level in levels_of_cell])

    # `default=0` so that an empty `arrs` yields an empty list instead of crashing.
    max_len = max((len(p) for p in ps), default=0)
    combined_parents_in_level = []
    for level in range(max_len):
        # Cells with fewer levels simply do not contribute to deeper levels.
        current_ps = [p[level] for p in ps if len(p) > level]
        combined_parents_in_level.append(np.concatenate(current_ps))
    return combined_parents_in_level
def compute_levels(parents): levels = np.zeros_like(parents) for i, p in enumerate(parents): if p == -1: levels[i] = 0 else: levels[i] = levels[p] + 1 return levels def compute_children_in_level( levels: np.ndarray, children_row_and_col: np.ndarray ) -> List[np.ndarray]: num_branches = len(levels) children_in_each_level = [] for l in range(1, np.max(levels) + 1): children_in_current_level = [] for b in range(num_branches): if levels[b] == l: children_in_current_level.append(children_row_and_col[b - 1]) children_in_current_level = np.asarray(children_in_current_level) children_in_each_level.append(children_in_current_level) return children_in_each_level def compute_parents_in_level(levels, par_inds, parents_row_and_col): level_of_parent = levels[par_inds] parents_in_each_level = [] for l in range(np.max(levels)): parents_inds_in_current_level = np.where(level_of_parent == l)[0] parents_in_current_level = parents_row_and_col[parents_inds_in_current_level] parents_in_current_level = np.asarray(parents_in_current_level) parents_in_each_level.append(parents_in_current_level) return parents_in_each_level def _compute_num_children(parents): num_branches = len(parents) num_children = [] for b in range(num_branches): n = np.sum(np.asarray(parents) == b) num_children.append(n) return num_children def _compute_index_of_child(parents): """For every branch, it returns the how many-eth child of its parent it is. Example: ``` parents = [-1, 0, 0, 1, 1, 1] _compute_index_of_child(parents) -> [-1, 0, 1, 0, 1, 2] ``` """ num_branches = len(parents) current_num_children_for_each_branch = np.zeros((num_branches,), np.dtype("int")) index_of_child = [-1] for b in range(1, num_branches): index_of_child.append(current_num_children_for_each_branch[parents[b]]) current_num_children_for_each_branch[parents[b]] += 1 return index_of_child
def compute_children_indices(parents) -> List[np.ndarray]:
    """Return all children indices of every branch.

    Args:
        parents: Parent index of every branch (`-1` for the root). Lists and
            arrays are both accepted.

    Example:
    ```
    parents = [-1, 0, 0]
    compute_children_indices(parents) -> [[1, 2], [], []]
    ```
    """
    # Without the conversion, `list == int` is a single `False` and every branch
    # would (wrongly) get no children.
    parents = np.asarray(parents)
    return [np.where(parents == b)[0] for b in range(len(parents))]
def get_num_neighbours(
    num_children: jnp.ndarray,
    ncomp_per_branch: int,
    num_branches: int,
):
    """Return the number of neighbours of every compartment.

    Every compartment defaults to 2 neighbours. Compartment
    `ncomp_per_branch - 1` (the last compartment of branch 0) is set to 1, and
    the first compartment of every branch is set to `num_children + 1`.

    NOTE(review): when `ncomp_per_branch == 1` the branch-0 override is
    overwritten by the first-compartment rule — confirm this is intended.
    """
    first_comp_of_each_branch = jnp.arange(num_branches) * ncomp_per_branch
    counts = jnp.full((num_branches * ncomp_per_branch,), 2.0)
    counts = counts.at[ncomp_per_branch - 1].set(1.0)
    return counts.at[first_comp_of_each_branch].set(num_children + 1.0)
def local_index_of_loc(
    loc: float, global_branch_ind: int, ncomp_per_branch
) -> int:
    """Return the local index of a compartment given a loc in [0, 1] and a branch.

    This is used because we specify locations such as synapses as a value between
    0 and 1. We have to convert this onto a discrete segment here. The
    compartment whose center `(i + 0.5) / ncomp` is closest to `loc` is chosen;
    ties go to the lower index.

    Args:
        loc: Location (in [0, 1]) along that branch.
        global_branch_ind: Index of the branch.
        ncomp_per_branch: Number of compartments of each branch (indexed by
            `global_branch_ind`), or a single int shared by all branches.

    Returns:
        The local index of the compartment.
    """
    if np.ndim(ncomp_per_branch) == 0:
        ncomp = int(ncomp_per_branch)
    else:
        ncomp = ncomp_per_branch[global_branch_ind]
    possible_locs = np.linspace(0.5 / ncomp, 1 - 0.5 / ncomp, ncomp)
    return int(np.argmin(np.abs(possible_locs - loc)))
def loc_of_index(global_comp_index, global_branch_index, ncomp_per_branch):
    """Return the location in [0, 1] of the center of a global compartment.

    The compartment's local index within its branch is obtained by subtracting the
    index of the branch's first compartment. That first index is taken from
    `cumsum_leading_zero(ncomp_per_branch)`, presumably a cumulative sum with a
    leading 0 — verify against `jaxley.utils.misc_utils`.
    """
    first_comp_of_branch = cumsum_leading_zero(ncomp_per_branch)[global_branch_index]
    local_index = global_comp_index - first_comp_of_branch
    return (local_index + 0.5) / ncomp_per_branch[global_branch_index]
def compute_g_long(rad1, rad2, g_a1, g_a2, l1, l2):
    """Return the axial conductance between two compartments.

    Equations taken from `https://en.wikipedia.org/wiki/Compartmental_neuron_models`.

    In terms of axial resistivities this is:
    g_long = 2 * pi * rad1^2 * rad2^2 / (l1 * r_a1 * rad2^2 + l2 * r_a2 * rad1^2)

    It is evaluated with `g_a = 1 / r_a`, so that `g_a = 0` (an infinite
    resistivity) is representable, whereas `g_a = inf` is not needed.
    """
    numerator = 2 * pi * rad1**2 * rad2**2 * g_a1 * g_a2
    denominator = l1 * g_a2 * rad2**2 + l2 * g_a1 * rad1**2
    return numerator / denominator
def g_long_by_surface_area(rad1, rad2, g_a1, g_a2, l1, l2):
    """Return the voltage coupling conductance between two compartments.

    Equations taken from `https://en.wikipedia.org/wiki/Compartmental_neuron_models`.

    The axial conductance (see `compute_g_long`) is divided by the surface area of
    the sink, A = 2 * pi * rad1 * l1, which gives:
    g_axial = rad1 * rad2^2 / (l1 * r_a1 * rad2^2 + l2 * r_a2 * rad1^2) / l1

    Conductances are passed as `g_a = 1 / r_a` so that `g_a = 0` is allowed.

    Units:
        `radius`: um
        `g_a`: Siemens / cm
        `r_a`: ohm cm (unused, just for reference)
        `length_single_compartment`: um
    """
    sink_surface_area = 2 * pi * rad1 * l1
    return compute_g_long(rad1, rad2, g_a1, g_a2, l1, l2) / sink_surface_area
def g_long_by_volume(rad1, rad2, g_a1, g_a2, l1, l2):
    """Return the ion diffusive constant between two compartments.

    The axial conductance (see `compute_g_long`) is divided by the volume of the
    sink, V = pi * rad1^2 * l1, which gives, in terms of conductances:
    g_axial = 2 * rad2^2 * g_a1 * g_a2 / (l1 * g_a2 * rad2^2 + l2 * g_a1 * rad1^2) / l1

    Conductances are passed as `g = 1 / r_a` so that `g = 0` is allowed; this is
    typically wanted for ion diffusion.

    Units:
        `radius`: um
        `g_a`: mM / liter / cm
        `l`: um
    """
    sink_volume = pi * rad1**2 * l1
    return compute_g_long(rad1, rad2, g_a1, g_a2, l1, l2) / sink_volume
def compute_impact_on_node(rad, g_a, l):
    r"""Return the weight with which a compartment influences its branch point.

    Kirchhoff's current law at a branch point reads
    `g_1 * (V_1 - V_b) + g_2 * (V_2 - V_b) = 0.0`, so only the *relative* size of
    the weights matters and any common constant factor may be dropped.

    With R_long = r_a * (L / 2) / crosssection, the conductance is
    g_long = crosssection * 2 / L / r_a \propto rad**2 / L / r_a.

    `g_a = 1 / r_a` is used so that `r_a = inf` (i.e. `g_a = 0`) is allowed.
    """
    radius_squared = rad**2
    return radius_squared * g_a / l
def remap_to_consecutive(arr):
    """Replace every integer by its rank among the sorted unique values of `arr`.

    E.g. `[0, 0, 1, 4, 4, 6, 6] -> [0, 0, 1, 2, 2, 3, 3]`
    """
    return jnp.unique(arr, return_inverse=True)[1]
# Vectorized `jnp.interp(x, xp, fp)`: `x` and `xp` are shared (in_axes=None) and
# `fp` is mapped over its columns (axis 1), so every column is interpolated
# independently. Used by `interpolate_xyzr`.
v_interp = vmap(jnp.interp, in_axes=(None, None, 1))
def interpolate_xyzr(loc: float, coords: np.ndarray):
    """Perform a linear interpolation between xyzr-coordinates.

    The interpolation variable is the cumulative path length along the branch,
    computed from the xyz columns only (`coords[:, :3]`) and normalized to [0, 1].
    Every column of `coords` is then interpolated at `loc`. That includes the
    radius column when an xyzr array is passed.

    Args:
        loc: The location in [0, 1] along the branch.
        coords: Array containing the reconstructed xyzr points of the branch, of
            shape `(num_points, num_columns)`.

    Return:
        Interpolated values at `loc`, one per column of `coords`. For an xyzr
        input this has shape `(4,)`.
    """
    dl = np.sqrt(np.sum(np.diff(coords[:, :3], axis=0) ** 2, axis=1))
    pathlens = np.insert(np.cumsum(dl), 0, 0)  # cumulative length of sections
    # Guard against division by zero when all points coincide.
    norm_pathlens = pathlens / np.maximum(1e-8, pathlens[-1])  # norm lengths to [0,1].
    return v_interp(loc, norm_pathlens, coords)
def params_to_pstate(
    params: List[Dict[str, jnp.ndarray]],
    indices_set_by_trainables: List[jnp.ndarray],
):
    """Convert `get_parameters()` output into the format produced by `.data_set()`.

    After `make_trainable()`, `params = get_parameters()` carries no indices.
    Indices are left out because `jax.grad` would otherwise differentiate them
    too once `params` is passed to `def simulate(params)`. `jx.integrate`
    therefore calls this function to attach the indices again.

    Every entry becomes `{"key": ..., "val": ..., "indices": ...}`. The key/value
    is the first item of the corresponding dict in `params`.
    """
    pstate = []
    for param, inds in zip(params, indices_set_by_trainables):
        first_key = list(param)[0]
        pstate.append({"key": first_key, "val": param[first_key], "indices": inds})
    return pstate
def convert_point_process_to_distributed(
    current: jnp.ndarray, radius: jnp.ndarray, length: jnp.ndarray
) -> jnp.ndarray:
    """Convert current point process (nA) to distributed current (uA/cm2).

    This function gets called for synapses and for external stimuli. The input
    array is never modified.

    Args:
        current: Current in `nA`.
        radius: Compartment radius in `um`.
        length: Compartment length in `um`.

    Return:
        Current in `uA/cm2`.
    """
    area = 2 * pi * radius * length
    # Out-of-place division: `current /= area` would mutate a caller's numpy array
    # in place (and fail for integer arrays).
    current_density = current / area  # nA / um^2
    return current_density * 100_000  # Convert (nA / um^2) to (uA / cm^2)
def build_branchpoint_group_inds(
    num_branchpoints, child_belongs_to_branchpoint, start_ind_for_branchpoints
):
    """Return consecutive group ids for parent- and child-side branchpoint indices.

    The parent side contributes one index per branchpoint and the child side one
    index per child. Both are offset by `start_ind_for_branchpoints`, concatenated
    (parents first) and remapped to consecutive integers.
    """
    offset = start_ind_for_branchpoints
    parent_side = offset + jnp.arange(num_branchpoints)
    child_side = offset + child_belongs_to_branchpoint
    return remap_to_consecutive(jnp.concatenate([parent_side, child_side]))
def compute_morphology_indices_in_levels(
    num_branchpoints,
    child_belongs_to_branchpoint,
    par_inds,
    child_inds,
):
    """Return (row, col) to build the sparse matrix defining the voltage eqs.

    Returns a dict with:
        "children": shape `(num_children, 2)`, columns (child branch index,
            branchpoint index of that child).
        "parents": shape `(num_branchpoints, 2)`, columns (parent branch index,
            branchpoint index `0, ..., num_branchpoints - 1`).

    This is run at `init`, not during runtime.
    """
    children = jnp.stack([child_inds, child_belongs_to_branchpoint], axis=1)
    parents = jnp.stack([par_inds, jnp.arange(num_branchpoints)], axis=1)
    return {"children": children, "parents": parents}
def group_and_sum(
    values_to_sum: jnp.ndarray, inds_to_group_by: jnp.ndarray, num_branchpoints: int
) -> jnp.ndarray:
    """Sum `values_to_sum` within groups that share the same integer index.

    Returns an array of length `num_branchpoints` whose entry `k` is the sum of
    all values whose entry in `inds_to_group_by` is `k`. This is used to
    construct the last diagonals at the branch points.
    """
    sums = jnp.zeros(num_branchpoints)
    # `.at[inds]` requires non-empty `inds`, which is not the case for plain
    # branches and compartments (zero branchpoints).
    if not num_branchpoints > 0:
        return sums
    return sums.at[inds_to_group_by].add(values_to_sum)
def query_channel_states_and_params(d, keys, idcs):
    """Return `{key: d[key][idcs]}` for every key in `keys`.

    This restricts a dict in which every item covers __all__ compartments to only
    the entries relevant for one channel. E.g.
    ```states = {'eCa': Array([ 0.,  0., nan]}```
    becomes
    ```states = {'eCa': Array([ 0.,  0.]}```

    Only the requested keys are visited, rather than every item of `d`.
    """
    return {key: d.get(key)[idcs] for key in keys}
def _compute_comp_coupling_conductances(
    radius_sink: jnp.ndarray,
    radius_source: jnp.ndarray,
    axial_conds_sink: jnp.ndarray,
    axial_conds_source: jnp.ndarray,
    length_sink: jnp.ndarray,
    length_source: jnp.ndarray,
    capacitance_sink: jnp.ndarray,
) -> jnp.ndarray:
    """Return coupling conductances into sink compartments, one row per state.

    `axial_conds_sink` / `axial_conds_source` have one row per state (voltage
    first, then the diffusing ions) and one column per edge.

    Row 0 of the result holds the voltage coupling. It is divided by the sink
    surface area and the sink capacitance, then converted to mS / cm^2. Rows 1:
    hold the ion-diffusion couplings. They are divided by the sink volume and
    not by the capacitance.
    """
    # Inner vmap: over edges. Outer vmap: over states; only the axial
    # conductances differ between states, the geometry is shared.
    conds_v = vmap(
        vmap(g_long_by_surface_area, in_axes=(0, 0, 0, 0, 0, 0)),
        in_axes=(None, None, 0, 0, None, None),
    )(
        radius_sink,
        radius_source,
        axial_conds_sink[:1],
        axial_conds_source[:1],
        length_sink,
        length_source,
    )
    # .at[0] because we only divide the axial voltage conductances by the
    # capacitance, _not_ the axial conductances of the diffusing ions.
    conds_v = conds_v.at[0].divide(capacitance_sink)
    # Multiply by 10**7 to convert (S / cm / um) -> (mS / cm^2).
    conds_v = conds_v.at[0].multiply(10**7)

    # For ion diffusion, we have to divide by the volume, not the surface area.
    conds_ions = vmap(
        vmap(g_long_by_volume, in_axes=(0, 0, 0, 0, 0, 0)),
        in_axes=(None, None, 0, 0, None, None),
    )(
        radius_sink,
        radius_source,
        axial_conds_sink[1:],
        axial_conds_source[1:],
        length_sink,
        length_source,
    )
    return jnp.concatenate([conds_v, conds_ions])


def compute_axial_conductances(
    comp_edges: pd.DataFrame,
    params: Dict[str, jnp.ndarray],
    diffusion_states: List[str],
) -> Dict[str, jnp.ndarray]:
    """Given `comp_edges`, radius, length, r_a, cm, compute the axial conductances.

    Note that the resulting voltage axial conductances will already be divided by
    the capacitance `cm`.

    Args:
        comp_edges: Edge table with columns `type`, `source` and `sink`. Type 0
            edges are compartment-to-compartment, types 1 and 2 are
            branchpoint-to-compartment, and types 3 and 4 are
            compartment-to-branchpoint.
        params: Must contain `axial_resistivity`, `radius`, `length`,
            `capacitance` and `axial_diffusion_{d}` for every diffusion state.
        diffusion_states: Names of the diffusing ion states.

    Returns:
        A dict keyed by `"v"` and every diffusion state. Each value concatenates
        the c2c, bp2c and c2bp edges, in that order (not in the row order of
        `comp_edges`).
    """
    num_states = len(diffusion_states) + 1
    axial_conductances = jnp.stack(
        [1 / params["axial_resistivity"]]
        + [params[f"axial_diffusion_{d}"] for d in diffusion_states]
    )

    # `Compartment-to-compartment` (c2c) axial coupling conductances.
    condition = comp_edges["type"].to_numpy() == 0
    source_comp_inds = np.asarray(comp_edges[condition]["source"].to_list())
    sink_comp_inds = np.asarray(comp_edges[condition]["sink"].to_list())
    if len(sink_comp_inds) > 0:
        conds_c2c = _compute_comp_coupling_conductances(
            params["radius"][sink_comp_inds],
            params["radius"][source_comp_inds],
            axial_conductances[:, sink_comp_inds],
            axial_conductances[:, source_comp_inds],
            params["length"][sink_comp_inds],
            params["length"][source_comp_inds],
            params["capacitance"][sink_comp_inds],
        )
    else:
        conds_c2c = jnp.asarray([[]] * num_states)

    # `branchpoint-to-compartment` (bp2c) axial coupling conductances. The
    # branchpoint side uses the sink's own radius/conductance and length 0.
    condition = comp_edges["type"].isin([1, 2])
    sink_comp_inds = np.asarray(comp_edges[condition]["sink"].to_list())
    if len(sink_comp_inds) > 0:
        conds_bp2c = _compute_comp_coupling_conductances(
            params["radius"][sink_comp_inds],
            params["radius"][sink_comp_inds],
            axial_conductances[:, sink_comp_inds],
            axial_conductances[:, sink_comp_inds],
            params["length"][sink_comp_inds],
            jnp.zeros_like(params["length"][sink_comp_inds]),  # l=0 for branchpoint.
            params["capacitance"][sink_comp_inds],
        )
    else:
        conds_bp2c = jnp.asarray([[]] * num_states)

    # `compartment-to-branchpoint` (c2bp) axial coupling conductances.
    condition = comp_edges["type"].isin([3, 4])
    source_comp_inds = np.asarray(comp_edges[condition]["source"].to_list())
    if len(source_comp_inds) > 0:
        conds_c2bp = vmap(
            vmap(compute_impact_on_node, in_axes=(0, 0, 0)), in_axes=(None, 0, None)
        )(
            params["radius"][source_comp_inds],
            axial_conductances[:, source_comp_inds],
            params["length"][source_comp_inds],
        )
        # For numerical stability. These values are very small, but their scale
        # does not matter.
        conds_c2bp *= 1_000
    else:
        conds_c2bp = jnp.asarray([[]] * num_states)

    # All axial coupling conductances.
    all_coupling_conds = jnp.concatenate([conds_c2c, conds_bp2c, conds_c2bp], axis=1)

    return {
        key: all_coupling_conds[i] for i, key in enumerate(["v"] + diffusion_states)
    }
def compute_children_and_parents(
    branch_edges: pd.DataFrame,
) -> Tuple[np.ndarray, np.ndarray, jnp.ndarray]:
    """Build indices used during `._init_morph_custom_spsolve()`.

    Args:
        branch_edges: Table with the columns `parent_branch_index` and
            `child_branch_index`; presumably one row per parent->child edge
            (confirm against the caller).

    Returns:
        A tuple `(par_inds, child_inds, child_belongs_to_branchpoint)`:
            par_inds: Sorted unique parent branch indices.
            child_inds: Child branch index of every row.
            child_belongs_to_branchpoint: For every row, the position of its
                parent in `par_inds`.
    """
    par_inds = branch_edges["parent_branch_index"].to_numpy()
    child_inds = branch_edges["child_branch_index"].to_numpy()
    # Ranks among the sorted unique parents, i.e. positions in `np.unique(par_inds)`.
    child_belongs_to_branchpoint = remap_to_consecutive(par_inds)
    par_inds = np.unique(par_inds)
    return par_inds, child_inds, child_belongs_to_branchpoint