# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
import numpy as np
def is_same_network(pre: "View", post: "View") -> bool:
    """Return whether two views live in one and the same network.

    Two conditions must both hold:

    * the base module of ``pre`` is a network, judged by its class name
      containing ``"network"`` (case-insensitive), and
    * ``pre`` and ``post`` share the identical base object.

    Args:
        pre: First view (typically the presynaptic side).
        post: Second view (typically the postsynaptic side).

    Returns:
        ``True`` if both conditions hold, ``False`` otherwise.
    """
    pre_base = pre.base
    post_base = post.base
    base_class_name = pre_base.__class__.__name__.lower()
    return ("network" in base_class_name) and (pre_base is post_base)
def sample_comp(cell_view: "View", num: int = 1, replace: bool = True) -> np.ndarray:
    """Randomly sample compartment indices from a view.

    Note that this returns plain indices (drawn from
    ``cell_view._comps_in_view``), not a View.

    Args:
        cell_view: View whose compartments are sampled from.
        num: Number of compartment indices to draw.
        replace: Whether to sample with replacement.

    Returns:
        A 1D numpy array of length ``num`` with entries drawn from
        ``cell_view._comps_in_view``.
    """
    return np.random.choice(cell_view._comps_in_view, num, replace=replace)
def connect(
    pre: "View",
    post: "View",
    synapse_type: "Synapse",
):
    """Connect compartments with a chemical synapse.

    ``pre`` and ``post`` are expected to be distinct compartments. Only their
    membership in the same network is checked here. The node tables of both
    views are handed to the network's ``_append_multiple_synapses``.

    Args:
        pre: View of the presynaptic compartment.
        post: View of the postsynaptic compartment.
        synapse_type: The synapse to append.

    Raises:
        AssertionError: If ``pre`` and ``post`` are not part of the same network.
    """
    assert is_same_network(pre, post), (
        "Pre and post compartments must be part of the same network."
    )
    network = pre.base
    network._append_multiple_synapses(pre.nodes, post.nodes, synapse_type)
def fully_connect(
    pre_cell_view: "View",
    post_cell_view: "View",
    synapse_type: "Synapse",
):
    """Appends multiple connections which build a fully connected layer.

    Every presynaptic cell is connected to every postsynaptic cell exactly once.
    Connections are from branch 0 location 0 to a randomly chosen branch and loc
    (sampled independently for every connection).

    Args:
        pre_cell_view: View of the presynaptic cell.
        post_cell_view: View of the postsynaptic cell.
        synapse_type: The synapse to append.
    """
    # Get pre- and postsynaptic cell indices.
    num_pre = len(pre_cell_view._cells_in_view)
    num_post = len(post_cell_view._cells_in_view)

    # Infer indices of (random) postsynaptic compartments: `num_pre` draws per
    # postsynaptic cell. pandas concatenates the per-group samples in group
    # order, so the draws come out grouped by postsynaptic cell.
    global_post_indices = (
        post_cell_view.nodes.groupby("global_cell_index")
        .sample(num_pre, replace=True)
        .index.to_numpy()
    )
    # Row `q` of the (num_post, num_pre) matrix holds the draws for
    # postsynaptic cell `q`. Reading it column by column ("F" order) puts
    # the draw for connection (pre p -> post q) at position
    # `p * num_post + q`, which matches the order of `pre_rows` below.
    # (A Fortran-order reshape followed by a C-order ravel only gives this
    # ordering for square layers.)
    global_post_indices = global_post_indices.reshape((num_post, num_pre)).ravel(
        order="F"
    )
    post_rows = post_cell_view.nodes.loc[global_post_indices]

    # Pre-synapse is at the zero-eth branch and zero-eth compartment.
    pre_rows = pre_cell_view.scope("local").branch(0).comp(0).nodes.copy()
    # Repeat each row `num_post` times consecutively. See SO 50788508.
    pre_rows = pre_rows.loc[pre_rows.index.repeat(num_post)].reset_index(drop=True)

    pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)
def sparse_connect(
    pre_cell_view: "View",
    post_cell_view: "View",
    synapse_type: "Synapse",
    p: float,
):
    """Appends multiple connections which build a sparse, randomly connected layer.

    Each (presynaptic cell, postsynaptic cell) pair is connected independently
    with probability ``p`` and at most once. Connections are from branch 0
    location 0 to a randomly chosen branch and loc.

    Args:
        pre_cell_view: View of the presynaptic cell.
        post_cell_view: View of the postsynaptic cell.
        synapse_type: The synapse to append.
        p: Probability of connection for each pre/post cell pair.
    """
    # Get pre- and postsynaptic cell indices.
    pre_cell_inds = np.asarray(pre_cell_view._cells_in_view)
    post_cell_inds = np.asarray(post_cell_view._cells_in_view)
    num_pre = len(pre_cell_inds)
    num_post = len(post_cell_inds)

    # One Bernoulli trial per (pre, post) pair. Drawing a total count and then
    # picking pre and post cells independently would allow duplicate synapses
    # between the same pair.
    connections = np.random.binomial(1, p, size=(num_pre, num_post))
    # `np.nonzero` returns indices in row-major order, so synapses come out
    # ordered by presynaptic cell (convenient when inspecting `.edges`).
    from_idx, to_idx = np.nonzero(connections)
    if from_idx.size == 0:
        # Nothing to connect.
        return
    pre_syn_neurons = pre_cell_inds[from_idx]
    post_syn_neurons = post_cell_inds[to_idx]

    # Post-synapse is a randomly chosen branch and compartment.
    global_post_indices = np.hstack(
        [
            sample_comp(post_cell_view.scope("global").cell(cell_idx))
            for cell_idx in post_syn_neurons
        ]
    )
    post_rows = post_cell_view.base.nodes.loc[global_post_indices]

    # Pre-synapse is at the zero-eth branch and zero-eth compartment.
    global_pre_indices = pre_cell_view.base._cumsum_ncomp_per_cell[pre_syn_neurons]
    pre_rows = pre_cell_view.base.nodes.loc[global_pre_indices]

    pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)
def connectivity_matrix_connect(
    pre_cell_view: "View",
    post_cell_view: "View",
    synapse_type: "Synapse",
    connectivity_matrix: np.ndarray[bool],
):
    """Appends multiple connections which build a custom connected network.

    Connects pre- and postsynaptic cells according to a custom connectivity matrix.
    A ``True`` entry at ``[i, j]`` connects the ``i``-th cell of ``pre_cell_view``
    to the ``j``-th cell of ``post_cell_view``.
    Connections are from branch 0 location 0 to a randomly chosen branch and loc.

    Args:
        pre_cell_view: View of the presynaptic cell.
        post_cell_view: View of the postsynaptic cell.
        synapse_type: The synapse to append.
        connectivity_matrix: A boolean matrix of shape ``(num_pre, num_post)``
            indicating the connections between cells.

    Raises:
        AssertionError: If the matrix does not have shape ``(num_pre, num_post)``
            or is not boolean.
    """
    # Get pre- and postsynaptic cell indices.
    pre_cell_inds = pre_cell_view._cells_in_view
    post_cell_inds = post_cell_view._cells_in_view

    # Validate the input before doing any work.
    assert connectivity_matrix.shape == (
        len(pre_cell_inds),
        len(post_cell_inds),
    ), "Connectivity matrix must have shape (num_pre, num_post)."
    assert connectivity_matrix.dtype == bool, "Connectivity matrix must be boolean."

    # Get connection pairs from the connectivity matrix.
    from_idx, to_idx = np.where(connectivity_matrix)
    if from_idx.size == 0:
        # Nothing to connect (`np.hstack` below would fail on an empty list).
        return
    pre_cell_inds = pre_cell_inds[from_idx]
    post_cell_inds = post_cell_inds[to_idx]

    # Setting the scope ensures that this works independently of the current
    # scope. Work on a copy so that adding the helper "index" column cannot
    # modify (or warn about) the view's own node table.
    pre_nodes = pre_cell_view.scope("local").branch(0).comp(0).nodes.copy()
    pre_nodes["index"] = pre_nodes.index
    pre_cell_nodes = pre_nodes.set_index("global_cell_index")

    # Sample random postsynaptic compartments (global comp indices).
    global_post_indices = np.hstack(
        [
            sample_comp(post_cell_view.scope("global").cell(cell_idx))
            for cell_idx in post_cell_inds
        ]
    )
    post_rows = post_cell_view.nodes.loc[global_post_indices]

    # Pre-synapse is at the zero-eth branch and zero-eth compartment.
    global_pre_indices = pre_cell_nodes.loc[pre_cell_inds, "index"].to_numpy()
    pre_rows = pre_cell_view.select(nodes=global_pre_indices).nodes

    pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)