Source code for jaxley.channels.non_capacitive.izhikevich
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
from typing import Optional
from warnings import warn
import jax
import jax.numpy as jnp
from jax import Array
from jaxley.channels import Channel
from jaxley.solver_gate import exponential_euler
[docs]
class Izhikevich(Channel):
"""Izhikevich neuron model.
The following parameters are registered in ``channel_params``:
.. list-table::
:widths: 25 15 50 10
:header-rows: 1
* - Name
- Default
- Description
- Unit
* - ``a``
- 0.02
- Time scale of the recovery variable ``u``.
- 1/ms
* - ``b``
- 0.2
- Sensitivity of the recovery variable ``u`` to the membrane potential ``v``.
- 1/ms
* - ``c``
- -65.0
- After-spike reset value of the membrane potential ``v``.
- mV
* - ``d``
- 8
- After-spike increment of the recovery variable ``u``.
- mV/ms
The following states are registered in ``channel_states``:
.. list-table::
:widths: 25 15 50 10
:header-rows: 1
* - Name
- Default
- Description
- Unit
* - ``u``
- 0.02
- Recovery variable.
- mV/ms
"""
def __init__(self, name: Optional[str] = None):
super().__init__(name)
self.channel_params = {
f"{self.name}_a": 0.02,
f"{self.name}_b": 0.2,
f"{self.name}_c": -65.0,
f"{self.name}_d": 8,
}
self.channel_states = {f"{self.name}_u": 0.0}
self.current_name = f"{self.name}_izhikevich"
warn(
"The `Izhikevich` channel does not support surrogate gradients. Its "
"gradient will be zero after every spike."
)
[docs]
def update_states(
self,
states: dict[str, Array],
params: dict[str, Array],
voltage: Array,
delta_t: float,
):
"""Reset the voltage when a spike occurs and log the spike"""
a = params[f"{self.name}_a"]
b = params[f"{self.name}_b"]
c = params[f"{self.name}_c"]
d = params[f"{self.name}_d"]
u = states[f"{self.name}_u"]
# Update the recovery variable u with exponential Euler.
u = exponential_euler(u, delta_t, b * voltage, 1 / a)
# Update voltages with Forward Euler because the vectorfield is nonlinear in v.
dv = (0.04 * voltage**2) + (5 * voltage) + 140 - u
voltage = voltage + delta_t * dv
condition = voltage >= 30.0
voltage = jax.lax.select(condition, c, voltage)
u = jax.lax.select(condition, u + d, u)
return {f"{self.name}_u": u, "v": voltage}
[docs]
def compute_current(
self,
states: dict[str, Array],
params: dict[str, Array],
voltage: Array,
delta_t: float,
):
return 0
[docs]
def init_state(
self,
states: dict[str, Array],
params: dict[str, Array],
voltage: Array,
delta_t: float,
):
prefix = self.name
return {f"{self.name}_u": params[f"{prefix}_b"] * voltage}