Source code for jaxley.synapses.alpha_synapse

# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>


from jax import Array

from jaxley.solver_gate import exponential_euler
from jaxley.synapses.synapse import Synapse


class AlphaSynapse(Synapse):
    r"""Alpha synapse which responds to binary pre-synaptic spike trains.

    This synapse is meant to be used together with pre-synaptic neurons whose
    voltage is binary (indicating spike or no spike).

    This synapse is implemented as two cascaded first-order linear ODEs:

    .. math::

        \tau_{\mathrm{rise}}\frac{d r(t)}{d t} = -r(t) + x(t)

    .. math::

        \tau_{\mathrm{decay}}\frac{d s(t)}{d t} = -s(t) + r(t)

    .. math::

        I = \overline{g} \cdot s \cdot (V_{\text{post}} - E)

    Here, :math:`x(t)` denotes the presynaptic input (typically a binary spike
    train), :math:`r(t)` is an intermediate *rise* state, and :math:`s(t)` is
    the synaptic state that determines the synaptic conductance. The current
    :math:`I` is positive when the post-synaptic voltage lies above the
    reversal potential :math:`E`.

    In discrete time, the input is :math:`x = V_{\text{pre}} / \Delta t`, so a
    pre-synaptic value of ``1`` held for a single time step approximates a
    unit-area impulse.

    For an impulse input :math:`x(t) = \delta(t)` and
    :math:`\tau_{\mathrm{rise}} \neq \tau_{\mathrm{decay}}`, the resulting
    synaptic kernel is

    .. math::

        s(t) \propto e^{-t / \tau_{\mathrm{decay}}} - e^{-t / \tau_{\mathrm{rise}}},
        \qquad t \ge 0.

    For :math:`\tau_{\mathrm{rise}} = \tau_{\mathrm{decay}} = \tau` (the default
    parameter values), the difference above vanishes. The kernel is instead the
    classical alpha function :math:`s(t) \propto t \, e^{-t / \tau}`.

    The synaptic parameters are:
        - ``gS``: The maximal conductance :math:`\overline{g}` (uS).
        - ``tau_rise``: The rise time constant :math:`\tau_{\mathrm{rise}}` (ms).
        - ``tau_decay``: The decay time constant :math:`\tau_{\mathrm{decay}}`
          (ms).

    The inserted cellular parameters are:
        - ``e_syn``: The synaptic reversal potential :math:`E` (mV). The
          current is computed from the *pre*-synaptic node's ``e_syn``. All
          synapses leaving one neuron therefore share the same reversal
          potential, which directly enforces Dale's law.

    The synaptic states are:
        - ``r``: Intermediate state representing the rising phase.
        - ``s``: Activity level of the synapse.

    Example usage
    ^^^^^^^^^^^^^

    .. code-block:: python

        import jaxley as jx
        from jaxley.connect import connect
        from jaxley.synapses import AlphaSynapse
        from jaxley.channels import Leak

        dummy = jx.Cell()
        cell = jx.read_swc("morph_ca1_n120.swc", ncomp=1)
        net = jx.Network([dummy, cell])
        net.cell(1).insert(Leak())

        # Connect pre-synaptic dummy to the morphologically detailed cell.
        connect(net.cell(0), net.cell(1).branch(5).comp(0), AlphaSynapse())
        net.set("AlphaSynapse_gS", 0.1)  # Synaptic strength.
        net.set("AlphaSynapse_tau_decay", 5.0)  # Decay time in ms.

        # Clamp the voltage of the pre-synaptic cell to the spike train.
        # `spike_train` (binary array) and `dt` are assumed to be defined.
        net.cell(0).set("v", 0.0)  # Initial state.
        net.cell(0).clamp("v", spike_train)

        net.cell(1).branch(5).comp(0).record()
        v = jx.integrate(net, delta_t=dt)
    """

    def __init__(self, name: str | None = None):
        """Register default parameters and states.

        Args:
            name: Optional synapse name. It is forwarded to ``Synapse.__init__``,
                and the resulting ``self._name`` prefixes every parameter and
                state key. The default is presumably the class name, given the
                ``"AlphaSynapse_gS"`` keys in the example; verify against
                ``Synapse``.
        """
        super().__init__(name)
        prefix = self._name
        self.synapse_params = {
            f"{prefix}_gS": 1e-4,  # uS
            f"{prefix}_tau_rise": 10.0,  # ms
            f"{prefix}_tau_decay": 10.0,  # ms
        }
        self.synapse_states = {
            f"{prefix}_r": 0.0,
            f"{prefix}_s": 0.0,
        }
        self.node_params = {
            f"{prefix}_e_syn": 0.0,  # mV
        }
        self.node_states = {}

    def update_states(
        self,
        synapse_states: dict[str, Array],
        synapse_params: dict[str, Array],
        pre_voltage: Array,
        post_voltage: Array,
        pre_states: dict[str, Array],
        post_states: dict[str, Array],
        pre_params: dict[str, Array],
        post_params: dict[str, Array],
        delta_t: float,
    ) -> dict:
        """Advance the rise state ``r`` and synaptic state ``s`` by one step.

        Both states are updated with ``exponential_euler``. The rise state is
        driven by ``pre_voltage / delta_t`` with time constant ``tau_rise``.
        The synaptic state is driven by ``r`` with time constant ``tau_decay``.

        Only ``synapse_states``, ``synapse_params``, ``pre_voltage`` and
        ``delta_t`` are used. The remaining arguments are part of the
        ``Synapse`` interface.

        Returns:
            Dict with the updated ``{name}_s`` and ``{name}_r`` states.
        """
        prefix = self._name
        r = synapse_states[f"{prefix}_r"]
        s = synapse_states[f"{prefix}_s"]
        tau_rise = synapse_params[f"{prefix}_tau_rise"]
        tau_decay = synapse_params[f"{prefix}_tau_decay"]

        # Scale the binary pre-synaptic value by 1/dt so that a single-step
        # spike has unit area, independent of the integration step size.
        r_inf = 1 / delta_t * pre_voltage
        r_new = exponential_euler(r, delta_t, r_inf, tau_rise)
        # `s` is driven by `r` from the start of the step (not `r_new`).
        s_new = exponential_euler(s, delta_t, r, tau_decay)
        return {f"{prefix}_s": s_new, f"{prefix}_r": r_new}

    def compute_current(
        self,
        synapse_states: dict[str, Array],
        synapse_params: dict[str, Array],
        pre_voltage: Array,
        post_voltage: Array,
        pre_states: dict[str, Array],
        post_states: dict[str, Array],
        pre_params: dict[str, Array],
        post_params: dict[str, Array],
        delta_t: float,
    ) -> Array:
        """Return the synaptic current ``gS * s * (post_voltage - e_syn)``.

        The reversal potential ``e_syn`` is read from ``pre_params``, i.e. from
        the pre-synaptic node. This enforces Dale's law. The current is
        positive when ``post_voltage`` exceeds ``e_syn``.
        """
        prefix = self._name
        g_syn = synapse_params[f"{prefix}_gS"] * synapse_states[f"{prefix}_s"]
        return g_syn * (post_voltage - pre_params[f"{prefix}_e_syn"])