# Source code for jaxley.modules.cell

# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Callable, Dict, List, Optional, Tuple, Union
from warnings import warn

import jax.numpy as jnp
import numpy as np
import pandas as pd

from jaxley.modules.base import Module
from jaxley.modules.branch import Branch
from jaxley.utils.cell_utils import (
    build_branchpoint_group_inds,
    compute_children_and_parents,
    compute_children_in_level,
    compute_children_indices,
    compute_levels,
    compute_morphology_indices_in_levels,
    compute_parents_in_level,
)
from jaxley.utils.misc_utils import cumsum_leading_zero, deprecated_kwargs
from jaxley.utils.solver_utils import (
    JaxleySolveIndexer,
    comp_edges_to_indices,
    remap_index_to_masked,
)


class Cell(Module):
    """Cell class.

    This class defines a single cell that can be simulated by itself or connected
    with synapses to build a network. A cell is made up of several branches and
    supports intricate cell morphologies.
    """

    # Cell-level parameters and states handed to `_append_params_and_states` in
    # `__init__`. NOTE(review): these are class-level dicts shared by every
    # instance; confirm they are never mutated in place.
    cell_params: Dict = {}
    cell_states: Dict = {}

    def __init__(
        self,
        branches: Optional[Union[Branch, List[Branch]]] = None,
        parents: Optional[List[int]] = None,
        xyzr: Optional[List[np.ndarray]] = None,
    ) -> None:
        """Initialize a cell.

        Args:
            branches: A single branch or a list of branches that make up the cell.
                If a single branch is provided, then the branch is repeated
                `len(parents)` times to create the cell. If `None`, a default
                `Branch()` is used.
            parents: The parent branch index for each branch. The first branch has
                no parent and is therefore set to -1. Defaults to `[-1]` (a
                single-branch cell) when `None`.
            xyzr: For every branch, the x, y, and z coordinates and the radius at
                the traced coordinates. Note that this is the full tracing (from
                SWC), not the stick representation coordinates.

        Raises:
            AssertionError: If `branches` is neither a `Branch`, a list, nor
                `None`; if `branches` is given without `parents`; if a list of
                branches does not have one parent per branch; or if `xyzr` does
                not have one entry per branch.
        """
        super().__init__()
        assert (
            isinstance(branches, (Branch, list)) or branches is None
        ), "Only Branch or List[Branch] is allowed."
        if branches is not None:
            assert (
                parents is not None
            ), "If `branches` is set, then you have to set `parents`."
            if isinstance(branches, list):
                assert len(parents) == len(
                    branches
                ), "Ensure equally many parents, i.e. len(branches) == len(parents)."

        branches = Branch() if branches is None else branches
        parents = [-1] if parents is None else parents

        # A single branch is re-used for every entry in `parents`.
        if isinstance(branches, Branch):
            branch_list = [branches for _ in range(len(parents))]
        else:
            branch_list = branches

        if xyzr is not None:
            assert len(xyzr) == len(
                parents
            ), "Ensure one `xyzr` entry per branch, i.e. len(xyzr) == len(parents)."
            self.xyzr = xyzr
        else:
            # For every branch (`len(parents)`), we have a start and end point (`2`)
            # and a (x,y,z,r) coordinate for each of them (`4`).
            # Since `xyzr` is only inspected at `.vis()` and because it depends on
            # the (potentially learned) length of every compartment, we only
            # populate self.xyzr at `.vis()`; until then it is an all-NaN
            # placeholder.
            self.xyzr = [np.full((2, 4), np.nan) for _ in range(len(parents))]

        self.total_nbranches = len(branch_list)
        self.nbranches_per_cell = [len(branch_list)]
        self.comb_parents = jnp.asarray(parents)
        self.comb_children = compute_children_indices(self.comb_parents)
        self._cumsum_nbranches = np.asarray([0, len(branch_list)])

        # Compartment structure. These arguments have to be rebuilt when
        # `.set_ncomp()` is run.
        self.ncomp_per_branch = np.asarray([branch.ncomp for branch in branch_list])
        self.ncomp = int(np.max(self.ncomp_per_branch))
        self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)
        self._internal_node_inds = np.arange(self.cumsum_ncomp[-1])

        # Build nodes. Has to be changed when `.set_ncomp()` is run.
        self.nodes = pd.concat([c.nodes for c in branch_list], ignore_index=True)
        self.nodes["global_comp_index"] = np.arange(self.cumsum_ncomp[-1])
        self.nodes["global_branch_index"] = np.repeat(
            np.arange(self.total_nbranches), self.ncomp_per_branch
        ).tolist()
        self.nodes["global_cell_index"] = np.repeat(0, self.cumsum_ncomp[-1]).tolist()
        self._update_local_indices()
        self._init_view()

        # Appending general parameters (radius, length, r_a, cm) and channel
        # parameters, as well as the states (v, and channel states).
        self._append_params_and_states(self.cell_params, self.cell_states)

        # Channels.
        self._gather_channels_from_constituents(branch_list)

        # One edge per non-root branch: (parent branch, child branch).
        self.branch_edges = pd.DataFrame(
            dict(
                parent_branch_index=self.comb_parents[1:],
                child_branch_index=np.arange(1, self.total_nbranches),
            )
        )

        # For morphology indexing.
        self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (
            compute_children_and_parents(self.branch_edges)
        )

        self._initialize()

    def _init_morph_jaxley_spsolve(self):
        """Initialize morphology for the custom sparse solver.

        Running this function is only required for custom Jaxley solvers, i.e.,
        for `voltage_solver={'jaxley.stone', 'jaxley.thomas'}`. However, at
        `.__init__()` (when this function is run) we do not yet know which solver
        the user will pick, so it is always run at `.__init__()`.

        Sets `self._solve_indexer`.
        """
        children_and_parents = compute_morphology_indices_in_levels(
            len(self._par_inds),
            self._child_belongs_to_branchpoint,
            self._par_inds,
            self._child_inds,
        )
        branchpoint_group_inds = build_branchpoint_group_inds(
            len(self._par_inds),
            self._child_belongs_to_branchpoint,
            self.cumsum_ncomp[-1],
        )
        parents = self.comb_parents
        children_inds = children_and_parents["children"]
        parents_inds = children_and_parents["parents"]

        levels = compute_levels(parents)
        children_in_level = compute_children_in_level(levels, children_inds)
        parents_in_level = compute_parents_in_level(
            levels, self._par_inds, parents_inds
        )

        # Pad every branch to the largest compartment count within its level, so
        # that all branches of one level share the same (padded) size.
        levels_and_ncomp = pd.DataFrame.from_dict(
            {
                "levels": levels,
                "ncomps": self.ncomp_per_branch,
            }
        )
        levels_and_ncomp["max_ncomp_in_level"] = levels_and_ncomp.groupby("levels")[
            "ncomps"
        ].transform("max")
        padded_cumsum_ncomp = cumsum_leading_zero(
            levels_and_ncomp["max_ncomp_in_level"].to_numpy()
        )

        # Generate mapping to deal with the masking which allows using the custom
        # sparse solver to deal with different ncomp per branch.
        remapped_node_indices = remap_index_to_masked(
            self._internal_node_inds,
            self.nodes,
            padded_cumsum_ncomp,
            self.ncomp_per_branch,
        )

        self._solve_indexer = JaxleySolveIndexer(
            cumsum_ncomp=padded_cumsum_ncomp,
            ncomp_per_branch=self.ncomp_per_branch,
            branchpoint_group_inds=branchpoint_group_inds,
            children_in_level=children_in_level,
            parents_in_level=parents_in_level,
            root_inds=np.asarray([0]),
            remapped_node_indices=remapped_node_indices,
        )

    def _init_morph_jax_spsolve(self):
        """For morphology indexing with the `jax.sparse` voltage solver.

        Explanation of `self._comp_edges['type']`:
        `type == 0`: compartment <--> compartment (within branch)
        `type == 1`: branchpoint --> parent-compartment
        `type == 2`: branchpoint --> child-compartment
        `type == 3`: parent-compartment --> branchpoint
        `type == 4`: child-compartment --> branchpoint

        Running this function is only required for generic sparse solvers, i.e.,
        for `voltage_solver='jax.sparse'`.

        Sets `self._comp_edges`, `self._n_nodes`, `self._data_inds`,
        `self._indices_jax_spsolve` and `self._indptr_jax_spsolve`.
        """
        # Edges between neighbouring compartments within each branch, in both
        # directions (i -> i+1 and i+1 -> i). A branch with a single compartment
        # contributes no edges.
        self._comp_edges = pd.concat(
            [
                pd.DataFrame.from_dict(
                    {
                        "source": list(range(cumsum_ncomp, ncomp - 1 + cumsum_ncomp))
                        + list(range(1 + cumsum_ncomp, ncomp + cumsum_ncomp)),
                        "sink": list(range(1 + cumsum_ncomp, ncomp + cumsum_ncomp))
                        + list(range(cumsum_ncomp, ncomp - 1 + cumsum_ncomp)),
                    }
                ).astype(int)
                for ncomp, cumsum_ncomp in zip(self.ncomp_per_branch, self.cumsum_ncomp)
            ],
            ignore_index=True,
        )
        self._comp_edges["type"] = 0

        # Edges from branchpoints to compartments. Branchpoint nodes are numbered
        # after all compartments (offset by `self.cumsum_ncomp[-1]`).
        # Parent side: presumably the last compartment of each parent branch,
        # given `cumsum_ncomp` is a cumulative count with a leading zero.
        branchpoint_to_parent_edges = pd.DataFrame.from_dict(
            {
                "source": np.arange(len(self._par_inds)) + self.cumsum_ncomp[-1],
                "sink": self.cumsum_ncomp[self._par_inds + 1] - 1,
                "type": 1,
            }
        )
        # Child side: the first compartment of each child branch.
        branchpoint_to_child_edges = pd.DataFrame.from_dict(
            {
                "source": self._child_belongs_to_branchpoint + self.cumsum_ncomp[-1],
                "sink": self.cumsum_ncomp[self._child_inds],
                "type": 2,
            }
        )
        self._comp_edges = pd.concat(
            [
                self._comp_edges,
                branchpoint_to_parent_edges,
                branchpoint_to_child_edges,
            ],
            ignore_index=True,
        )

        # Edges from compartments to branchpoints: the same edges with source and
        # sink swapped.
        parent_to_branchpoint_edges = branchpoint_to_parent_edges.rename(
            columns={"sink": "source", "source": "sink"}
        )
        parent_to_branchpoint_edges["type"] = 3
        child_to_branchpoint_edges = branchpoint_to_child_edges.rename(
            columns={"sink": "source", "source": "sink"}
        )
        child_to_branchpoint_edges["type"] = 4

        self._comp_edges = pd.concat(
            [
                self._comp_edges,
                parent_to_branchpoint_edges,
                child_to_branchpoint_edges,
            ],
            ignore_index=True,
        )

        n_nodes, data_inds, indices, indptr = comp_edges_to_indices(self._comp_edges)
        self._n_nodes = n_nodes
        self._data_inds = data_inds
        self._indices_jax_spsolve = indices
        self._indptr_jax_spsolve = indptr