# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
import numpy as np
def is_same_network(pre: "View", post: "View") -> bool:
    """Tell whether two views point at one and the same network object.

    Two conditions must hold:

    - ``pre.base`` and ``post.base`` are the identical object (``is``), and
    - the class name of that base contains ``"network"`` (case-insensitive).

    Args:
        pre: First view; only its ``base`` attribute is read.
        post: Second view; only its ``base`` attribute is read.

    Returns:
        ``True`` if both conditions hold, ``False`` otherwise.
    """
    # Different base objects can never be "the same network".
    if pre.base is not post.base:
        return False
    # Same base object: it still has to be a network (judged by class name).
    return "network" in pre.base.__class__.__name__.lower()
def sample_comp(cell_view: "View", num: int = 1, replace=True) -> "CompartmentView":
    """Randomly draw compartment entries from a cell view.

    The draw is made with ``np.random.choice`` (NumPy's global random state) over
    ``cell_view._comps_in_view``.

    Args:
        cell_view: View whose ``_comps_in_view`` entries are sampled from.
        num: Number of entries to draw.
        replace: Whether to sample with replacement.

    Returns:
        The output of ``np.random.choice``, i.e. ``num`` entries taken from
        ``cell_view._comps_in_view``.

    NOTE(review): the return annotation says ``CompartmentView``, but the value
    returned here is whatever ``np.random.choice`` produces -- confirm which of the
    two is intended.
    """
    candidates = cell_view._comps_in_view
    return np.random.choice(candidates, size=num, replace=replace)
[docs]
def connect(
    pre: "View",
    post: "View",
    synapse_type: "Synapse",
):
    """Insert synapses between explicitly selected compartments of one network.

    Both views must stem from the same network. When `pre` and `post` each hold a
    single compartment, one synapse is created. When they each hold `N`
    compartments, `N` synapses are created and paired by position (first-to-first,
    second-to-second, ...).

    Args:
        pre: View of the presynaptic compartment(s).
        post: View of the postsynaptic compartment(s).
        synapse_type: The type of synapse to use.

    Raises:
        AssertionError: If `pre` and `post` are not part of the same network.

    Example usage
    ^^^^^^^^^^^^^

    Example 1: A single synapse from one compartment to another:

    ::

        from jaxley.connect import connect
        from jaxley.synapses import IonotropicSynapse

        net = jx.Network([cell for _ in range(10)])
        connect(
            net.cell(0).branch(0).comp(0),
            net.cell(1).branch(0).comp(0),
            IonotropicSynapse(),
        )
        print(net.edges)

    Example 2: `N` synapses from `N` compartments to `N` other compartments:

    ::

        from jaxley.connect import connect
        from jaxley.synapses import IonotropicSynapse

        net = jx.Network([cell for _ in range(10)])
        connect(
            net.cell(0).branch([0, 1]).comp(0),
            net.cell(1).branch([2, 3]).comp(0),
            IonotropicSynapse(),
        )
        print(net.edges)
    """
    same_net = is_same_network(pre, post)
    assert same_net, "Pre and post compartments must be part of the same network."

    # The shared base object registers the synapses; the node tables of the two
    # views tell it which compartments to wire up.
    network = pre.base
    network._append_multiple_synapses(pre.nodes, post.nodes, synapse_type)
[docs]
def fully_connect(
    pre_cell_view: "View",
    post_cell_view: "View",
    synapse_type: "Synapse",
    random_post_comp: bool = False,
):
    """Fully (densely) connect cells of a network with synapses.

    Every cell in `pre_cell_view` is connected to every cell in `post_cell_view`,
    giving `num_pre * num_post` synapses. The presynaptic site is always branch 0,
    compartment 0 of the presynaptic cell. The postsynaptic site is the first
    compartment row of each postsynaptic cell in the view, unless
    `random_post_comp=True`, in which case it is a randomly drawn compartment of
    the postsynaptic cell (drawn independently for every synapse).

    Args:
        pre_cell_view: View of the presynaptic cell.
        post_cell_view: View of the postsynaptic cell.
        synapse_type: The synapse to append.
        random_post_comp: If True, randomly samples the postsynaptic compartments.

    Example usage
    ^^^^^^^^^^^^^

    The following example inserts 12 synapses (3 x 4).

    ::

        from jaxley.connect import fully_connect

        net = jx.Network([cell for _ in range(10)])
        fully_connect(
            net.cell([0, 1, 2]),
            net.cell([3, 4, 5, 6]),
            IonotropicSynapse(),
        )
        print(net.edges)
    """
    num_pre = len(pre_cell_view._cells_in_view)
    num_post = len(post_cell_view._cells_in_view)

    # Pre-synapse is at the zero-eth branch and zero-eth compartment.
    pre_rows = pre_cell_view.scope("local").branch(0).comp(0).nodes.copy()
    # Repeat every presynaptic row `num_post` times:
    # [pre0, pre0, ..., pre1, pre1, ...]. See SO 50788508.
    pre_rows = pre_rows.loc[pre_rows.index.repeat(num_post)]

    post_nodes = post_cell_view.nodes
    if random_post_comp:
        # `groupby(...).sample` returns the draws grouped by postsynaptic cell, i.e.
        # laid out as (num_post, num_pre): `num_pre` draws for the first post cell,
        # then `num_pre` draws for the second one, and so on.
        sampled = (
            post_nodes.groupby("global_cell_index")
            .sample(num_pre, replace=True)
            .index.to_numpy()
        )
        # The pre rows are repeated, so the post side must be in tile order
        # [post0, post1, ..., post0, post1, ...]. Reshaping with the true layout
        # (num_post, num_pre) and transposing hands each pre cell exactly one
        # draw from every post cell, also when num_pre != num_post.
        global_post_comp_indices = sampled.reshape(num_post, num_pre).T.flatten()
    else:
        # First compartment row of every postsynaptic cell. This is computed
        # without adding a temporary column to the caller's `nodes` table.
        first_comp_per_cell = (
            post_nodes.index.to_series()
            .groupby(post_nodes["global_cell_index"].to_numpy())
            .first()
            .to_numpy()
        )
        to_idx = np.tile(np.arange(num_post), num_pre)
        global_post_comp_indices = first_comp_per_cell[to_idx]

    post_rows = post_nodes.loc[global_post_comp_indices]
    pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)
[docs]
def sparse_connect(
    pre_cell_view: "View",
    post_cell_view: "View",
    synapse_type: "Synapse",
    p: float,
    random_post_comp: bool = False,
):
    """Sparsely connect cells of a network with synapses.

    The number of synapses is drawn from a binomial distribution with
    `num_pre * num_post` trials and success probability `p`. For every synapse, a
    presynaptic and a postsynaptic cell are then drawn independently (with
    replacement) from the two views.

    Connections are from branch 0 location 0 of the pre-synaptic cell to branch 0
    location 0 of the post-synaptic cell unless random_post_comp=True, in which
    case a random compartment of the drawn post-synaptic cell is used.

    NOTE: This function does not generate sparse random connectivity with random graph
    generation methodology. Cells may be connected multiple times and p=1.0 does
    not fully connect.

    Args:
        pre_cell_view: View of the presynaptic cell.
        post_cell_view: View of the postsynaptic cell.
        synapse_type: The synapse to append.
        p: Probability of connection.
        random_post_comp: If True, randomly samples the postsynaptic compartments.

    Returns:
        None. If zero connections are drawn, nothing is appended to the network.

    Example usage
    ^^^^^^^^^^^^^

    The following example inserts approximately 6 synapses (3 x 4 = 12 possible
    synapses, with connection probability 0.5).

    ::

        from jaxley.connect import sparse_connect
        from jaxley.synapses import IonotropicSynapse

        net = jx.Network([cell for _ in range(10)])
        sparse_connect(
            net.cell([0, 1, 2]),
            net.cell([3, 4, 5, 6]),
            IonotropicSynapse(),
            p=0.5,
        )
        print(net.edges)
    """
    # Get pre- and postsynaptic cell indices.
    pre_cell_inds = pre_cell_view._cells_in_view
    post_cell_inds = post_cell_view._cells_in_view
    num_pre = len(pre_cell_inds)
    num_post = len(post_cell_inds)

    num_connections = np.random.binomial(num_pre * num_post, p)
    if num_connections == 0:
        # Don't do any of the following connections if no synapse is inserted.
        return

    pre_syn_neurons = np.random.choice(pre_cell_inds, size=num_connections)
    post_syn_neurons = np.random.choice(post_cell_inds, size=num_connections)

    # Sort the synapses only for convenience of inspecting `.edges`. The same
    # permutation is applied to both sides so that the drawn pairs stay intact.
    sorting = np.argsort(pre_syn_neurons)
    pre_syn_neurons = pre_syn_neurons[sorting]
    post_syn_neurons = post_syn_neurons[sorting]

    # Pre-synapse is at the zero-eth branch and zero-eth compartment. Look it up
    # once per distinct cell and expand back to one entry per synapse.
    unique_pre, inverse_pre = np.unique(pre_syn_neurons, return_inverse=True)
    global_pre_indices = (
        pre_cell_view.scope("global")
        .cell(unique_pre)
        .scope("local")
        .branch(0)
        .comp(0)
        .nodes.index
    )
    pre_rows = pre_cell_view.base.nodes.loc[global_pre_indices[inverse_pre]]

    if random_post_comp:
        post_nodes = post_cell_view.nodes
        # Work on a copy so that the helper column never touches the view itself.
        candidates = post_nodes[
            post_nodes["global_cell_index"].isin(post_syn_neurons)
        ].copy()
        candidates["orig_index"] = candidates.index

        # Draw, for each post-synaptic cell, as many compartments as there are
        # synapses targeting that cell.
        unique_post, counts = np.unique(post_syn_neurons, return_counts=True)
        n_samples_dict = dict(zip(unique_post, counts))
        sampled = candidates.groupby("global_cell_index").apply(
            lambda grp: grp.sample(n=n_samples_dict[grp.name], replace=True)
        )
        sampled_comp_indices = sampled["orig_index"].to_numpy()

        # `sampled` is grouped by ascending post-synaptic cell, whereas the
        # synapses are ordered by pre-synaptic cell. Scatter the draws back to the
        # synapse slots of their cell; otherwise sorted pre cells would be paired
        # with sorted post cells instead of the pairs that were actually drawn.
        slot_order = np.argsort(post_syn_neurons, kind="stable")
        global_post_comp_indices = np.empty_like(sampled_comp_indices)
        global_post_comp_indices[slot_order] = sampled_comp_indices
        post_rows = post_nodes.loc[global_post_comp_indices]
    else:
        # Post-synapse also at the zero-eth branch and zero-eth compartment.
        unique_post, inverse_post = np.unique(post_syn_neurons, return_inverse=True)
        global_post_indices = (
            post_cell_view.scope("global")
            .cell(unique_post)
            .scope("local")
            .branch(0)
            .comp(0)
            .nodes.index
        )
        post_rows = post_cell_view.base.nodes.loc[global_post_indices[inverse_post]]

    # `num_connections > 0` is guaranteed here, so there is always something to add.
    pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)
[docs]
def connectivity_matrix_connect(
    pre_cell_view: "View",
    post_cell_view: "View",
    synapse_type: "Synapse",
    connectivity_matrix: np.ndarray[bool],
    random_post_comp: bool = False,
):
    """Connect cells of a network with synapses via a boolean connectivity matrix.

    A `True` entry at position `(i, j)` creates one synapse from the `i`-th cell of
    `pre_cell_view` to the `j`-th cell of `post_cell_view`.

    The presynaptic site is the first compartment row of the presynaptic cell in
    the view (branch 0 location 0 for whole cells). The postsynaptic site is the
    first compartment row of the postsynaptic cell unless random_post_comp=True,
    in which case a random compartment of that cell is drawn for every synapse.

    Args:
        pre_cell_view: View of the presynaptic cell.
        post_cell_view: View of the postsynaptic cell.
        synapse_type: The synapse to append.
        connectivity_matrix: A boolean matrix of shape `(num_pre, num_post)`
            indicating the connections between cells. The dtype must be `bool`;
            entries are never interpreted as synaptic weights.
        random_post_comp: If True, randomly samples the postsynaptic compartments.

    Raises:
        AssertionError: If `connectivity_matrix` does not have shape
            `(num_pre, num_post)` or is not of boolean dtype.

    Example usage
    ^^^^^^^^^^^^^

    The following generates a random 10 x 10 boolean matrix and uses it to connect the
    neurons in a network.

    ::

        from jaxley.connect import connectivity_matrix_connect
        from jaxley.synapses import IonotropicSynapse

        net = jx.Network([cell for _ in range(10)])
        connectivity_matrix = np.random.choice([False, True], size=(10, 10))
        connectivity_matrix_connect(
            net.cell("all"),
            net.cell("all"),
            IonotropicSynapse(),
            connectivity_matrix,
        )
        print(net.edges)
    """
    # Get pre- and postsynaptic cell counts and validate the matrix against them.
    num_pre = len(pre_cell_view._cells_in_view)
    num_post = len(post_cell_view._cells_in_view)

    assert connectivity_matrix.shape == (
        num_pre,
        num_post,
    ), "Connectivity matrix must have shape (num_pre, num_post)."
    assert connectivity_matrix.dtype == bool, "Connectivity matrix must be boolean."

    # Get pre to post connection pairs from the connectivity matrix. `np.where`
    # lists them in row-major order, i.e. sorted by presynaptic position.
    from_idx, to_idx = np.where(connectivity_matrix)

    pre_nodes = pre_cell_view.nodes
    post_nodes = post_cell_view.nodes

    # First compartment row of every presynaptic cell. No temporary column is
    # written to the views' tables, so passing the very same view object as both
    # `pre_cell_view` and `post_cell_view` works as well.
    first_pre_comp = (
        pre_nodes.index.to_series()
        .groupby(pre_nodes["global_cell_index"].to_numpy())
        .first()
        .to_numpy()
    )
    pre_rows = pre_cell_view.select(nodes=first_pre_comp[from_idx]).nodes

    if random_post_comp:
        global_to_idx = post_nodes["global_cell_index"].unique()[to_idx]
        # Restrict to the targeted post-synaptic cells; work on a copy so that the
        # helper column does not leak into the view.
        candidates = post_nodes[
            post_nodes["global_cell_index"].isin(global_to_idx)
        ].copy()
        candidates["orig_index"] = candidates.index

        # Draw, for each post-synaptic cell, as many compartments as there are
        # synapses targeting that cell.
        unique_cells, counts = np.unique(global_to_idx, return_counts=True)
        n_samples_dict = dict(zip(unique_cells, counts))
        sampled = candidates.groupby("global_cell_index").apply(
            lambda grp: grp.sample(n=n_samples_dict[grp.name], replace=True)
        )
        sampled_comp_indices = sampled["orig_index"].to_numpy()

        # `sampled` is grouped by ascending post-synaptic cell, but the connection
        # list is ordered like `to_idx`. Scatter the draws back to the slots of
        # their cell so that every synapse reaches the post cell the matrix asks
        # for.
        slot_order = np.argsort(global_to_idx, kind="stable")
        global_post_comp_indices = np.empty_like(sampled_comp_indices)
        global_post_comp_indices[slot_order] = sampled_comp_indices
    else:
        # First compartment row of every postsynaptic cell.
        first_post_comp = (
            post_nodes.index.to_series()
            .groupby(post_nodes["global_cell_index"].to_numpy())
            .first()
            .to_numpy()
        )
        global_post_comp_indices = first_post_comp[to_idx]

    post_rows = post_cell_view.select(nodes=global_post_comp_indices).nodes
    pre_cell_view.base._append_multiple_synapses(pre_rows, post_rows, synapse_type)