# Source code for jaxley.utils.syn_utils
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
from typing import Tuple
import jax.numpy as jnp
import numpy as np
from jax.lax import ScatterDimensionNumbers, scatter_add
# [docs]
def gather_synapes(
    number_of_compartments: int,
    post_syn_comp_inds: np.ndarray,
    current_each_synapse_voltage_term: jnp.ndarray,
    current_each_synapse_constant_term: jnp.ndarray,
) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Compute the total synaptic current arriving at each compartment.

    All this does is sum the synaptic currents that come into the same
    compartment. The voltage term and the constant term are accumulated
    separately. Each result has one element per compartment, and compartments
    without an incoming synapse receive ``0``.

    Args:
        number_of_compartments: Total number of compartments, i.e. the length
            of each returned array.
        post_syn_comp_inds: Postsynaptic compartment index of every synapse,
            one entry per synapse. Out-of-range indices are handled according
            to the default ``mode`` of ``jax.lax.scatter_add``.
        current_each_synapse_voltage_term: Voltage term of the current of every
            synapse, aligned with ``post_syn_comp_inds``.
        current_each_synapse_constant_term: Constant term of the current of
            every synapse, aligned with ``post_syn_comp_inds``.

    Returns:
        Tuple ``(incoming_currents_voltages, incoming_currents_constant)``.
        Each element has shape ``(number_of_compartments,)`` and the dtype of
        the corresponding input term.
    """
    # Turn ``[synapse]`` into ``[synapse, 1]``: every row addresses a single
    # element of the 1D per-compartment accumulator.
    scatter_inds = jnp.asarray(post_syn_comp_inds)[:, None]
    dnums = ScatterDimensionNumbers(
        update_window_dims=(),
        inserted_window_dims=(0,),
        scatter_dims_to_operand_dims=(0,),
    )

    def sum_per_compartment(current_each_synapse):
        """Scatter-add one per-synapse current vector onto the compartments."""
        current_each_synapse = jnp.asarray(current_each_synapse)
        # `lax.scatter_add` requires operand and updates to share a dtype, so
        # the accumulator follows the currents' dtype rather than jnp's default
        # (which would raise a TypeError e.g. for float32 currents under x64).
        incoming = jnp.zeros(
            (number_of_compartments,), dtype=current_each_synapse.dtype
        )
        return scatter_add(incoming, scatter_inds, current_each_synapse, dnums)

    incoming_currents_voltages = sum_per_compartment(
        current_each_synapse_voltage_term
    )
    incoming_currents_constant = sum_per_compartment(
        current_each_synapse_constant_term
    )
    return incoming_currents_voltages, incoming_currents_constant