Source code for jaxley.modules.branch

# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Callable, Dict, List, Optional, Tuple, Union
from warnings import warn

import jax.numpy as jnp
import numpy as np
import pandas as pd

from jaxley.modules.base import Module
from jaxley.modules.compartment import Compartment
from jaxley.utils.cell_utils import compute_children_and_parents
from jaxley.utils.misc_utils import cumsum_leading_zero, deprecated_kwargs
from jaxley.utils.solver_utils import JaxleySolveIndexer, comp_edges_to_indices


class Branch(Module):
    """A single, unbranched cable made of compartments.

    A branch is a linear segment of several compartments. It can be simulated
    on its own or used as a building block of a cell, where each of its ends
    may be connected to no, one, or several other branches to form more
    intricate morphologies.
    """

    # Class-level defaults passed to `_append_params_and_states` in `__init__`.
    branch_params: Dict = {}
    branch_states: Dict = {}

    def __init__(
        self,
        compartments: Optional[Union[Compartment, List[Compartment]]] = None,
        ncomp: Optional[int] = None,
    ):
        """Assemble a branch from one or more compartments.

        Args:
            compartments: Either a single compartment or a list of compartments
                that make up the branch. If `None`, a default `Compartment()`
                is used.
            ncomp: Number of compartments in the branch. If `compartments` is
                a single compartment, that same compartment object is repeated
                `ncomp` times to create the branch. Defaults to 1 when
                `compartments` is `None`. When a list is passed, the length of
                the list determines the number of compartments.

        Raises:
            AssertionError: If `compartments` is neither `None`, a
                `Compartment`, nor a list, or if a single `Compartment` is
                passed without `ncomp`.
        """
        super().__init__()

        assert (
            compartments is None or isinstance(compartments, (Compartment, List))
        ), "Only Compartment or List[Compartment] is allowed."
        if isinstance(compartments, Compartment):
            assert (
                ncomp is not None
            ), "If `compartments` is not a list then you have to set `ncomp`."

        # Fill in defaults for whatever the caller left unspecified.
        template = Compartment() if compartments is None else compartments
        n_repeats = 1 if ncomp is None else ncomp

        if isinstance(template, Compartment):
            # The list holds `n_repeats` references to the same object.
            compartment_list = [template] * n_repeats
        else:
            compartment_list = template

        # Size bookkeeping for a structure consisting of exactly one branch.
        num_comps = len(compartment_list)
        self.ncomp = num_comps
        self.ncomp_per_branch = np.asarray([num_comps])
        self.total_nbranches = 1
        self.nbranches_per_cell = [1]
        self._cumsum_nbranches = jnp.asarray([0, 1])
        self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)

        # Indexing: stack the per-compartment node tables and label every row.
        node_tables = [comp.nodes for comp in compartment_list]
        self.nodes = pd.concat(node_tables, ignore_index=True)
        self._append_params_and_states(self.branch_params, self.branch_states)
        self.nodes["global_comp_index"] = list(range(num_comps))
        self.nodes["global_branch_index"] = [0] * num_comps
        self.nodes["global_cell_index"] = [0] * num_comps
        self._update_local_indices()
        self._init_view()

        # Channels.
        self._gather_channels_from_constituents(compartment_list)

        # The edge table between branches is created empty.
        self.branch_edges = pd.DataFrame(
            {"parent_branch_index": [], "child_branch_index": []}
        )

        # For morphology indexing.
        parent_inds, child_inds, child_to_branchpoint = compute_children_and_parents(
            self.branch_edges
        )
        self._par_inds = parent_inds
        self._child_inds = child_inds
        self._child_belongs_to_branchpoint = child_to_branchpoint
        self._internal_node_inds = jnp.arange(num_comps)

        self._initialize()

        # Coordinates: a (2, 4) array of NaN placeholders.
        self.xyzr = [np.full((2, 4), np.nan)]

    def _init_morph_jaxley_spsolve(self):
        """Create the `JaxleySolveIndexer` and store it in `_solve_indexer`.

        The branchpoint group indices and the per-level children/parent lists
        are passed as empty, and index 0 is passed as the only root.
        """
        no_branchpoints = np.array([], dtype=int)
        self._solve_indexer = JaxleySolveIndexer(
            cumsum_ncomp=self.cumsum_ncomp,
            ncomp_per_branch=self.ncomp_per_branch,
            branchpoint_group_inds=no_branchpoints,
            remapped_node_indices=self._internal_node_inds,
            children_in_level=[],
            parents_in_level=[],
            root_inds=np.asarray([0]),
        )

    def _init_morph_jax_spsolve(self):
        """Initialize morphology for the jax sparse voltage solver.

        Explanation of `self._comp_edges['type']`:
        `type == 0`: compartment <--> compartment (within branch)
        `type == 1`: branchpoint --> parent-compartment
        `type == 2`: branchpoint --> child-compartment
        `type == 3`: parent-compartment --> branchpoint
        `type == 4`: child-compartment --> branchpoint

        Every edge built here is assigned `type == 0`.
        """
        # Neighbouring compartments are linked in both directions.
        lower = list(range(self.ncomp - 1))
        upper = list(range(1, self.ncomp))
        self._comp_edges = pd.DataFrame.from_dict(
            {
                "source": lower + upper,
                "sink": upper + lower,
            }
        )
        self._comp_edges["type"] = 0

        (
            self._n_nodes,
            self._data_inds,
            self._indices_jax_spsolve,
            self._indptr_jax_spsolve,
        ) = comp_edges_to_indices(self._comp_edges)

    def __len__(self) -> int:
        """Return the number of compartments in the branch."""
        return self.ncomp