# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
from typing import Optional
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike
from jaxley.channels import Channel
from jaxley.solver_gate import save_exp, solve_gate_exponential
class HH(Channel):
    """Hodgkin-Huxley channel based on Sterratt, Graham, Gillies & Einevoll.

    The total current is the sum of a sodium conductance gated by ``m**3 * h``,
    a potassium conductance gated by ``n**4`` and a constant leak conductance,
    each driving towards its own reversal potential. All parameter and state
    keys are prefixed with the channel's name (``self._name``).
    """

    def __init__(self, name: Optional[str] = None):
        # Declares that ``compute_current`` returns a current density in
        # mA/cm^2. Kept before ``super().__init__`` as in the original ordering;
        # NOTE(review): the base class may read this flag — confirm.
        self.current_is_in_mA_per_cm2 = True
        super().__init__(name)
        prefix = self._name
        # Maximal conductances (S/cm^2) and reversal potentials (presumably mV).
        self.channel_params = {
            f"{prefix}_gNa": 0.12,
            f"{prefix}_gK": 0.036,
            f"{prefix}_gLeak": 0.0003,
            f"{prefix}_eNa": 50.0,
            f"{prefix}_eK": -77.0,
            f"{prefix}_eLeak": -54.3,
        }
        # Default gate values; ``init_state`` computes the steady-state values
        # for a given voltage instead.
        self.channel_states = {
            f"{prefix}_m": 0.2,
            f"{prefix}_h": 0.2,
            f"{prefix}_n": 0.2,
        }
        # NOTE(review): unlike the parameter/state keys, the current name does
        # not include ``prefix`` — confirm this is intended for renamed channels.
        self.current_name = "i_HH"

    def update_states(
        self,
        states: dict[str, Array],
        dt: float,
        v: ArrayLike,
        params: dict[str, Array],
    ) -> dict[str, Array]:
        """Return updated HH channel state.

        Each gate is advanced by ``dt`` with ``solve_gate_exponential``, using
        the voltage-dependent rates ``(alpha, beta)`` evaluated at ``v``.

        Args:
            states: Current gate values keyed ``{name}_m``, ``{name}_h``,
                ``{name}_n``.
            dt: Time step.
            v: Membrane voltage.
            params: Channel parameters (unused: the gate rates depend on ``v``
                only).

        Returns:
            New values of the three gates, keyed like ``states``.
        """
        prefix = self._name
        m, h, n = states[f"{prefix}_m"], states[f"{prefix}_h"], states[f"{prefix}_n"]
        new_m = solve_gate_exponential(m, dt, *self.m_gate(v))
        new_h = solve_gate_exponential(h, dt, *self.h_gate(v))
        new_n = solve_gate_exponential(n, dt, *self.n_gate(v))
        return {f"{prefix}_m": new_m, f"{prefix}_h": new_h, f"{prefix}_n": new_n}

    def compute_current(
        self, states: dict[str, Array], v: ArrayLike, params: dict[str, Array]
    ) -> Array:
        """Return current through HH channels.

        Computes ``gNa*m**3*h*(v - eNa) + gK*n**4*(v - eK) + gLeak*(v - eLeak)``.
        With conductances in S/cm^2 and ``v`` in mV the result is in mA/cm^2.

        Args:
            states: Gate values keyed ``{name}_m``, ``{name}_h``, ``{name}_n``.
            v: Membrane voltage.
            params: Conductances and reversal potentials keyed as in
                ``channel_params``.

        Returns:
            The summed sodium, potassium and leak current.
        """
        prefix = self._name
        m, h, n = states[f"{prefix}_m"], states[f"{prefix}_h"], states[f"{prefix}_n"]
        gNa = params[f"{prefix}_gNa"] * (m**3) * h  # S/cm^2
        gK = params[f"{prefix}_gK"] * n**4  # S/cm^2
        gLeak = params[f"{prefix}_gLeak"]  # S/cm^2
        return (
            gNa * (v - params[f"{prefix}_eNa"])
            + gK * (v - params[f"{prefix}_eK"])
            + gLeak * (v - params[f"{prefix}_eLeak"])
        )

    def init_state(
        self,
        states: dict[str, ArrayLike],
        v: ArrayLike,
        params: dict[str, ArrayLike],
        delta_t: float,
    ) -> dict[str, Array]:
        """Initialize the gates at the fixed point of their dynamics at ``v``.

        For each gate the steady state is ``alpha / (alpha + beta)``, with the
        rates evaluated at the given voltage.

        Args:
            states: Current states (unused).
            v: Membrane voltage at which the steady state is computed.
            params: Channel parameters (unused).
            delta_t: Time step (unused).

        Returns:
            Steady-state values keyed ``{name}_m``, ``{name}_h``, ``{name}_n``.
        """
        prefix = self._name
        alpha_m, beta_m = self.m_gate(v)
        alpha_h, beta_h = self.h_gate(v)
        alpha_n, beta_n = self.n_gate(v)
        return {
            f"{prefix}_m": alpha_m / (alpha_m + beta_m),
            f"{prefix}_h": alpha_h / (alpha_h + beta_h),
            f"{prefix}_n": alpha_n / (alpha_n + beta_n),
        }

    @staticmethod
    def m_gate(v):
        """Return the rates ``(alpha, beta)`` of the ``m`` gate at voltage ``v``."""
        alpha = 0.1 * _vtrap(-(v + 40), 10)
        beta = 4.0 * save_exp(-(v + 65) / 18)
        return alpha, beta

    @staticmethod
    def h_gate(v):
        """Return the rates ``(alpha, beta)`` of the ``h`` gate at voltage ``v``."""
        alpha = 0.07 * save_exp(-(v + 65) / 20)
        beta = 1.0 / (save_exp(-(v + 35) / 10) + 1)
        return alpha, beta

    @staticmethod
    def n_gate(v):
        """Return the rates ``(alpha, beta)`` of the ``n`` gate at voltage ``v``."""
        alpha = 0.01 * _vtrap(-(v + 55), 10)
        beta = 0.125 * save_exp(-(v + 65) / 80)
        return alpha, beta
def _vtrap(x, y, eps=1e-4):
    """Return ``x / (exp(x / y) - 1)``, finite and differentiable at ``x == 0``.

    The direct formula is ``0 / 0`` (NaN) at ``x == 0`` and loses precision
    close to it. Where ``|x / y| < eps``, the series
    ``y * (1 - u/2 + u**2/12)`` with ``u = x / y`` is used instead; its value at
    ``x == 0`` is the limit ``y``. Elsewhere the exact formula is evaluated
    through ``save_exp`` exactly as before.

    Args:
        x: Numerator and (scaled) exponent argument; scalar or array.
        y: Nonzero scale of the exponent.
        eps: Threshold on ``|x / y|`` below which the series is used.

    Returns:
        Array with the broadcast shape of ``x`` and ``y``.
    """
    u = x / y
    near_zero = jnp.abs(u) < eps
    # Double-``where``: in the near-zero region feed the exact branch a harmless
    # value (u == 1) so it never evaluates 0/0. Without this, the discarded NaN
    # would still propagate into gradients through ``jnp.where``.
    safe_x = jnp.where(near_zero, y, x)
    exact = safe_x / (save_exp(safe_x / y) - 1.0)
    series = y * (1.0 - u / 2.0 + u * u / 12.0)
    return jnp.where(near_zero, series, exact)