Source code for jaxley.channels.non_capacitive.izhikevich
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
from typing import Optional
from warnings import warn
import jax
import jax.numpy as jnp
from jaxley.channels import Channel
from jaxley.solver_gate import exponential_euler
[docs]
class Izhikevich(Channel):
"""Izhikevich neuron model."""
def __init__(self, name: Optional[str] = None):
self.current_is_in_mA_per_cm2 = True
super().__init__(name)
self.channel_params = {
f"{self.name}_a": 0.02,
f"{self.name}_b": 0.2,
f"{self.name}_c": -65.0,
f"{self.name}_d": 8,
}
self.channel_states = {f"{self.name}_u": 0.0}
self.current_name = f"{self.name}_izhikevich"
warn(
"The `Izhikevich` channel does not support surrogate gradients. Its "
"gradient will be zero after every spike."
)
[docs]
def update_states(self, states, dt, v, params):
"""Reset the voltage when a spike occurs and log the spike"""
a = params[f"{self.name}_a"]
b = params[f"{self.name}_b"]
c = params[f"{self.name}_c"]
d = params[f"{self.name}_d"]
u = states[f"{self.name}_u"]
# Update the recovery variable u with exponential Euler.
u = exponential_euler(u, dt, b * v, 1 / a)
# Update voltages with Forward Euler because the vectorfield is nonlinear in v.
dv = (0.04 * v**2) + (5 * v) + 140 - u
v = v + dt * dv
condition = v >= 30.0
v = jax.lax.select(condition, c, v)
u = jax.lax.select(condition, u + d, u)
return {f"{self.name}_u": u, "v": v}
[docs]
def compute_current(self, states, v, params):
return jnp.zeros((1,))
[docs]
def init_state(self, states, v, params, delta_t):
prefix = self.name
return {f"{self.name}_u": params[f"{prefix}_b"] * v}