Source code for jaxley.channels.non_capacitive.spike
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
from typing import Optional
from warnings import warn
import jax
import jax.numpy as jnp
from jaxley.channels import Channel
[docs]
class Fire(Channel):
"""Mechanism to reset the voltage when it crosses a threshold.
When combined with a ``Leak`` channel, this can be used to implement
leaky-integrate-and-fire neurons.
Note that, after the voltage is reset by this channel, other channels (or external
currents), can still modify the membrane voltage `within the same time step`.
"""
def __init__(self, name: Optional[str] = None):
self.current_is_in_mA_per_cm2 = True
super().__init__(name)
self.channel_params = {f"{self.name}_vth": -50, f"{self.name}_vreset": -70}
self.channel_states = {f"{self.name}_spikes": False}
self.current_name = f"{self.name}_fire"
warn(
"The `Fire` channel does not support surrogate gradients. Its gradient "
"will be zero after every spike."
)
[docs]
def update_states(self, states, dt, v, params):
"""Reset the voltage when a spike occurs and log the spike"""
prefix = self._name
vreset = params[f"{prefix}_vreset"]
vth = params[f"{prefix}_vth"]
spike_occurred = v > vth
v = jax.lax.select(spike_occurred, vreset, v)
return {"v": v, f"{self.name}_spikes": spike_occurred}
[docs]
def compute_current(self, states, v, params):
return jnp.zeros((1,))
[docs]
def init_state(self, states, v, params, delta_t):
return {}