# Source code for jaxley.channels.non_capacitive.spike

# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from typing import Optional

import jax.numpy as jnp
from jax import Array

from jaxley.channels import Channel
from jaxley.solver_gate import heaviside


class Fire(Channel):
    """Mechanism to reset the voltage when it crosses a threshold.

    When combined with a ``Leak`` channel, this can be used to implement
    leaky-integrate-and-fire neurons.

    Note that, after the voltage is reset by this channel, other channels (or
    external currents) can still modify the membrane voltage `within the same
    time step`.

    Note as well that this channel implements a surrogate gradient through the
    use of the ``heaviside`` function in ``update_states()``. This allows the
    user to perform gradient descent on networks using this channel despite the
    ``Fire`` mechanism being non-differentiable.

    The following parameters are registered in ``channel_params``:

    .. list-table::
       :widths: 25 15 50 10
       :header-rows: 1

       * - Name
         - Default
         - Description
         - Unit
       * - ``Fire_vth``
         - -50.0
         - Threshold for firing.
         - mV
       * - ``Fire_vreset``
         - -70.0
         - The reset for the voltage after a spike.
         - mV

    The following states are registered in ``channel_states``:

    .. list-table::
       :widths: 25 15 50 10
       :header-rows: 1

       * - Name
         - Default
         - Description
         - Unit
       * - ``Fire_spikes``
         - False
         - Whether or not a spike occurred.
         - 1
    """

    def __init__(self, name: Optional[str] = None):
        super().__init__(name)
        # All parameter/state keys are prefixed with the channel name so that
        # several channels can coexist in the same compartment.
        self.channel_params = {f"{self.name}_vth": -50, f"{self.name}_vreset": -70}
        self.channel_states = {f"{self.name}_spikes": False}
        self.current_name = f"{self.name}_fire"

    def update_states(
        self,
        states: dict[str, Array],
        params: dict[str, Array],
        voltage: Array,
        delta_t: float,
    ):
        """Reset the voltage when a spike occurs and log the spike.

        Args:
            states: Current channel states (not read by this mechanism).
            params: Channel parameters; ``{name}_vth`` and ``{name}_vreset`` are
                read.
            voltage: Membrane voltage (mV).
            delta_t: Time step (not used by this mechanism).

        Returns:
            A dict with the voltage under ``"v"`` (set to ``vreset`` where a
            spike occurred, unchanged elsewhere) and the spike indicator, as
            returned by ``heaviside``, under ``"{name}_spikes"``.
        """
        # Use the same name source for reading params and writing the state key.
        prefix = self.name
        vreset = params[f"{prefix}_vreset"]
        vth = params[f"{prefix}_vth"]

        # Evaluate the step function once. It serves both as the spike
        # indicator and as the blending weight between the unchanged voltage and
        # the reset value. Its surrogate gradient (see ``heaviside``) is what
        # makes this reset trainable.
        spike_occurred = heaviside(voltage - vth)
        voltage = voltage * (1 - spike_occurred) + vreset * spike_occurred
        return {"v": voltage, f"{prefix}_spikes": spike_occurred}

    def compute_current(
        self,
        states: dict[str, Array],
        params: dict[str, Array],
        voltage: Array,
        delta_t: float,
    ):
        """Return a zero membrane current.

        This mechanism acts only by resetting the voltage in
        ``update_states()``; it contributes no current.
        """
        return 0

    def init_state(
        self,
        states: dict[str, Array],
        params: dict[str, Array],
        voltage: Array,
        delta_t: float,
    ):
        """Return an empty dict: this channel computes no initial states."""
        return {}