# Source code for jaxley.morphology.distance_utils
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
from typing import Dict, List, Union
import jax.numpy as jnp
import networkx as nx
from jaxley.modules.base import to_graph
[docs]
def distance_direct(
    startpoint: "View",
    endpoints: Union["Branch", "Cell", "View"],
) -> List[float]:
    """Compute the straight-line distance from one compartment to others.

    The ``x``, ``y`` and ``z`` columns of ``startpoint.nodes`` and
    ``endpoints.nodes`` are read, and the euclidean (line of sight) distance
    between the start compartment and every endpoint compartment is returned.

    Args:
        startpoint: A single compartment from which to compute the distance.
        endpoints: One or multiple compartments to which to compute the distance.

    Returns:
        One distance per row of ``endpoints.nodes``, in row order.

    Raises:
        AssertionError: If ``startpoint`` contains more than one compartment.

    Example usage
    ^^^^^^^^^^^^^

    The following computes the direct (line of sight) distance between the
    zero-eth soma compartment and all other compartments. It then saves this
    distance in `cell.nodes["direct_dist_from_soma"]`.

    ::

        from jaxley.morphology import distance_direct

        cell.compute_compartment_centers()  # necessary if you modified branch length.
        direct_dists = distance_direct(cell.soma.branch(0).comp(0), cell)
        cell.nodes["direct_dist_from_soma"] = direct_dists
    """
    coords = ["x", "y", "z"]
    num_roots = len(startpoint.nodes.index)
    assert num_roots == 1, "Cannot use multiple root nodes."

    # Shape (3,) for the root and (n_endpoints, 3) for the targets, so the
    # subtraction broadcasts the root over every endpoint row.
    origin = startpoint.nodes[coords].to_numpy()[0]
    targets = endpoints.nodes[coords].to_numpy()

    # The squaring happens on the numpy arrays; only the reduction and the
    # square root run through jax.
    squared_offsets = (origin - targets) ** 2
    return jnp.sqrt(jnp.sum(squared_offsets, axis=1))
[docs]
def distance_pathwise(
    startpoint: "View", endpoints: Union["Branch", "Cell", "View"]
) -> List[float]:
    """Returns the pathwise distance between a root and other compartments.

    The morphology of ``startpoint.base`` is converted to an undirected graph and
    Dijkstra's algorithm is run once from the start compartment. Every edge is
    weighted by the mean of the ``length`` attributes of the two nodes it
    connects, so the result is the length (in micrometers) of the shortest path
    along the morphology. Nodes without a ``length`` attribute (branchpoints)
    contribute a length of zero.

    Args:
        startpoint: A single compartment from which to compute the distance.
        endpoints: One or multiple compartments to which to compute the distance.

    Returns:
        A list of distances, one per entry of ``endpoints.nodes.index`` and in
        the same order.

    Raises:
        AssertionError: If ``startpoint`` contains more than one compartment.
        networkx.NetworkXNoPath: If an endpoint cannot be reached from the
            start compartment.

    Example usage
    ^^^^^^^^^^^^^

    Example 1: The following computes the pathwise distance between the zero-eth
    soma compartment and all other compartments. It then saves this distance in
    `cell.nodes["path_dist_from_soma"]`.

    ::

        from jaxley.morphology import distance_pathwise

        path_dists = distance_pathwise(cell.soma.branch(0).comp(0), cell)
        cell.nodes["path_dist_from_soma"] = path_dists

    Example 2: The following computes the pathwise distance between two
    compartments.

    ::

        dist = distance_pathwise(cell.branch(8).comp(2), cell.branch(2).comp(0))
    """
    assert len(startpoint.nodes.index) == 1, "Cannot use multiple root nodes."
    root = startpoint.nodes.index[0]
    endpoint_inds = endpoints.nodes.index

    graph = to_graph(startpoint.base)
    graph = nx.to_undirected(graph)

    # Set default for branchpoints, which have no length of their own.
    for _, data in graph.nodes(data=True):
        data.setdefault("length", 0.0)

    def edge_weight(u: int, v: int, d: Dict) -> float:
        """Return the length of the edge between nodes `u` and `v`.

        Args:
            u: Start node of an edge.
            v: End node of an edge.
            d: Dictionary of edge attributes (unused because all our attributes
                are node attributes, not edge attributes).

        Returns:
            The mean of the two node lengths, i.e. half of each node's length.
        """
        return (graph.nodes[u]["length"] + graph.nodes[v]["length"]) / 2

    # Nothing to look up; skip the search entirely.
    if len(endpoint_inds) == 0:
        return []

    # One single-source search yields the distance to every reachable node, so
    # Dijkstra does not have to be re-run from the same root for each endpoint.
    dist_from_root = nx.single_source_dijkstra_path_length(
        graph, root, weight=edge_weight
    )

    try:
        return [dist_from_root[end] for end in endpoint_inds]
    except KeyError as err:
        # Nodes missing from the result were never reached from `root`. Raise
        # the same exception type that `nx.dijkstra_path_length` uses.
        raise nx.NetworkXNoPath(
            f"Node {err.args[0]} not reachable from {root}"
        ) from err