Source code for jaxley.synapses.spike

# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>


from jax import Array

from jaxley.synapses.synapse import Synapse


[docs] class SpikeSynapse(Synapse): r"""Synapse to be used in networks of LIF neurons. This synapse is meant to be used in networks of LIF neurons, together with the `Fire` channel. After the pre-synaptic neuron has `Fire_spikes=1.0` (a spike), the state of the synapse gets increased immediately and then decays exponentially. This synapse implements the following equations: .. math:: I = \overline{g}\, \cdot s .. math:: \tau \frac{\text{d}s}{\text{d}t} = -s .. math:: s \leftarrow s + 1 \quad \text{if } \, \text{Fire}_{\text{pre}} = 1 The synaptic parameters are: - ``gS``: the maximal conductance :math:`\overline{g}` (uS). - ``decay_tau``: The time constant of the decay :math:`\tau` (ms). The synaptic state is: - ``s``: the activity level of the synapse. Example usage ^^^^^^^^^^^^^ .. code-block:: python from jaxley.channels import Leak, Fire from jaxley.connect import fully_connect from jaxley.synapse import SpikeSynapse cell = jx.Cell() net = jx.Network([cell for _ in range(5)]) net.insert(Leak()) net.insert(Fire()) fully_connect(net.cell("all"), net.cell("all"), SpikeSynapse()) net.record("v") v = jx.integrate(net, t_max=100.0, delta_t=0.025) """ def __init__(self, name: str | None = None): super().__init__(name) prefix = self._name self.synapse_params = { f"{prefix}_gS": 1e-4, # uS f"{prefix}_s_decay": 10.0, # unitless } self.synapse_states = {f"{prefix}_s": 0.0} self.node_parmas = {} self.node_states = {"Fire_spikes": 0.0}
[docs] def update_states( self, synapse_states: dict[str, Array], synapse_params: dict[str, Array], pre_voltage: Array, post_voltage: Array, pre_states: dict[str, Array], post_states: dict[str, Array], pre_params: dict[str, Array], post_params: dict[str, Array], delta_t: float, ) -> dict: """Return updated synapse state and current.""" prefix = self._name s = synapse_states[f"{prefix}_s"] s_decay = synapse_params[f"{prefix}_s_decay"] spike = pre_states["Fire_spikes"] new_s = (spike * 1.0) + ((1 - spike) * (s - (s_decay * s * delta_t))) return {f"{prefix}_s": new_s}
[docs] def compute_current( self, synapse_states: dict[str, Array], synapse_params: dict[str, Array], pre_voltage: Array, post_voltage: Array, pre_states: dict[str, Array], post_states: dict[str, Array], pre_params: dict[str, Array], post_params: dict[str, Array], delta_t: float, ) -> float: """Return updated synapse state and current.""" prefix = self._name g_syn = synapse_params[f"{prefix}_gS"] * synapse_states[f"{prefix}_s"] return g_syn