# Source code for jaxley.modules.branch
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
from typing import Callable, Dict, List, Optional, Tuple, Union
from warnings import warn
import jax.numpy as jnp
import numpy as np
import pandas as pd
from jaxley.modules.base import Module
from jaxley.modules.compartment import Compartment
from jaxley.utils.cell_utils import compute_children_and_parents
from jaxley.utils.misc_utils import cumsum_leading_zero, deprecated_kwargs
# [docs]
class Branch(Module):
    """A branch made up of one or multiple compartments (without branchpoints).

    This class defines a single branch that can be simulated by itself or
    connected to build a cell. A branch is an unbranched neurite of several
    compartments and can be connected to no, one or more other branches at each end to
    build more intricate cell morphologies (via ``jx.Cell``).
    """

    # Branch-level parameters and states. Both are empty here and are handed to
    # `_append_params_and_states` when a branch is constructed.
    branch_params: Dict = {}
    branch_states: Dict = {}

    def __init__(
        self,
        compartments: Optional[Union[Compartment, List[Compartment]]] = None,
        ncomp: Optional[int] = None,
    ):
        """Build a branch from a single compartment or a list of compartments.

        Args:
            compartments: A single compartment or a list of compartments that make up
                the branch. If `None`, a default `Compartment()` is used.
            ncomp: Number of compartments to divide the branch into. If `compartments`
                is a single compartment, then `ncomp` is required and the compartment
                is repeated `ncomp` times to create the branch. If `compartments` is
                `None`, `ncomp` defaults to 1. If `compartments` is a list, `ncomp`
                is ignored and the length of the list is used instead.

        Raises:
            AssertionError: If `compartments` is neither a `Compartment`, a list, nor
                `None`, or if a single `Compartment` is passed without `ncomp`.
            ValueError: If the resulting branch would contain no compartments (an
                empty list, or `ncomp < 1` for a single compartment).
        """
        super().__init__()

        # The checks below stay as `assert` so that callers keep receiving
        # `AssertionError` for these two conditions.
        assert (
            isinstance(compartments, (Compartment, list)) or compartments is None
        ), "Only Compartment or List[Compartment] is allowed."
        if isinstance(compartments, Compartment):
            assert (
                ncomp is not None
            ), "If `compartments` is not a list then you have to set `ncomp`."

        # Fill in defaults only after validation, so that an explicitly passed
        # compartment without `ncomp` is still rejected above.
        compartments = Compartment() if compartments is None else compartments
        ncomp = 1 if ncomp is None else ncomp

        if isinstance(compartments, Compartment):
            compartment_list = [compartments] * ncomp
        else:
            compartment_list = compartments

        # Without this check an empty branch only fails further down, inside
        # `pd.concat`, with the unhelpful "No objects to concatenate" (also a
        # `ValueError`, so the exception type is unchanged).
        if not compartment_list:
            raise ValueError(
                "A branch must contain at least one compartment: pass a non-empty "
                "list of compartments or `ncomp >= 1`."
            )

        # Bookkeeping for a module that consists of exactly one branch.
        self.ncomp = len(compartment_list)
        self.ncomp_per_branch = np.asarray([self.ncomp])
        self.total_nbranches = 1
        self.nbranches_per_cell = [1]
        self._cumsum_nbranches = jnp.asarray([0, 1])
        self.cumsum_ncomp = cumsum_leading_zero(self.ncomp_per_branch)

        # Indexing: stack the node tables of all compartments, then assign global
        # indices. Branch and cell index are 0 for every compartment.
        self.nodes = pd.concat([c.nodes for c in compartment_list], ignore_index=True)
        self._append_params_and_states(self.branch_params, self.branch_states)
        self.nodes["global_comp_index"] = np.arange(self.ncomp).tolist()
        self.nodes["global_branch_index"] = [0] * self.ncomp
        self.nodes["global_cell_index"] = [0] * self.ncomp

        # Channels.
        self._gather_channels_from_constituents(compartment_list)

        # A single branch has no branch-to-branch connections.
        self.branch_edges = pd.DataFrame(
            dict(parent_branch_index=[], child_branch_index=[])
        )

        # For morphology indexing.
        self._par_inds, self._child_inds, self._child_belongs_to_branchpoint = (
            compute_children_and_parents(self.branch_edges)
        )
        self._internal_node_inds = jnp.arange(self.ncomp)

        # Coordinates: a (2, 4) array of NaN as placeholder.
        self.xyzr = [np.full((2, 4), np.nan)]

        self.initialize()

    def _init_comp_graph(self):
        """Build the compartment-level connectivity of the branch.

        Sets the following attributes:

        - `_comp_edges`: DataFrame with columns `source`, `sink` and `type`. Every
          pair of neighbouring compartments `(i, i + 1)` is listed in both directions
          and all edges have `type` 0.
        - `_n_nodes`: Number of nodes, equal to `ncomp`.
        - `_off_diagonal_inds`: Integer array that stacks the `source` and `sink`
          columns of `_comp_edges`.
        """
        # Compartment edges: first all edges `i -> i + 1`, then all `i + 1 -> i`.
        lower = list(range(self.ncomp - 1))
        upper = list(range(1, self.ncomp))
        self._comp_edges = pd.DataFrame.from_dict(
            {
                "source": lower + upper,
                "sink": upper + lower,
            }
        )
        self._comp_edges["type"] = 0
        self._n_nodes = self.ncomp

        # `off_diagonal_inds`.
        sources = np.asarray(self._comp_edges["source"].to_list())
        sinks = np.asarray(self._comp_edges["sink"].to_list())
        self._off_diagonal_inds = jnp.stack([sources, sinks]).astype(int)

    def __len__(self) -> int:
        """Return the number of compartments in the branch."""
        return self.ncomp