# Source code for jaxley.channels.hh

# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
from typing import Optional

import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike

from jaxley.channels import Channel
from jaxley.solver_gate import save_exp, solve_gate_exponential


class HH(Channel):
    """Hodgkin-Huxley channel based on Sterratt, Graham, Gillies & Einevoll.

    The channel sums three currents:

    - a sodium current gated by ``m`` and ``h``;
    - a potassium current gated by ``n``;
    - a leak current.

    All parameter and state keys are prefixed with the channel's name
    (``self._name``).
    """

    def __init__(self, name: Optional[str] = None):
        # NOTE(review): assigned before ``super().__init__`` runs, presumably so
        # the base class can see it during initialization. Confirm against
        # ``Channel``. It reflects that ``compute_current`` multiplies S/cm^2
        # conductances by voltage differences.
        self.current_is_in_mA_per_cm2 = True
        super().__init__(name)
        prefix = self._name
        # Maximal conductances (S/cm^2, see ``compute_current``) and reversal
        # potentials (presumably mV).
        self.channel_params = {
            f"{prefix}_gNa": 0.12,
            f"{prefix}_gK": 0.036,
            f"{prefix}_gLeak": 0.0003,
            f"{prefix}_eNa": 50.0,
            f"{prefix}_eK": -77.0,
            f"{prefix}_eLeak": -54.3,
        }
        # Default gate values.
        self.channel_states = {
            f"{prefix}_m": 0.2,
            f"{prefix}_h": 0.2,
            f"{prefix}_n": 0.2,
        }
        # The current name is fixed; it is not prefixed with the channel name.
        self.current_name = "i_HH"

    def update_states(
        self,
        states: dict[str, Array],
        dt: float,
        v: float,
        params: dict[str, Array],
    ) -> dict[str, Array]:
        """Return updated HH channel state.

        Each gate is advanced by ``dt`` with ``solve_gate_exponential``, using
        the ``(alpha, beta)`` rate pair that ``m_gate``, ``h_gate`` and
        ``n_gate`` return at voltage ``v``. ``params`` is not used here.
        """
        prefix = self._name
        m, h, n = states[f"{prefix}_m"], states[f"{prefix}_h"], states[f"{prefix}_n"]
        new_m = solve_gate_exponential(m, dt, *self.m_gate(v))
        new_h = solve_gate_exponential(h, dt, *self.h_gate(v))
        new_n = solve_gate_exponential(n, dt, *self.n_gate(v))
        return {f"{prefix}_m": new_m, f"{prefix}_h": new_h, f"{prefix}_n": new_n}

    def compute_current(
        self, states: dict[str, Array], v: float, params: dict[str, Array]
    ) -> Array:
        """Return current through HH channels.

        Computes
        ``gNa * m**3 * h * (v - eNa) + gK * n**4 * (v - eK) + gLeak * (v - eLeak)``.
        """
        prefix = self._name
        m, h, n = states[f"{prefix}_m"], states[f"{prefix}_h"], states[f"{prefix}_n"]

        gNa = params[f"{prefix}_gNa"] * (m**3) * h  # S/cm^2
        gK = params[f"{prefix}_gK"] * n**4  # S/cm^2
        gLeak = params[f"{prefix}_gLeak"]  # S/cm^2

        return (
            gNa * (v - params[f"{prefix}_eNa"])
            + gK * (v - params[f"{prefix}_eK"])
            + gLeak * (v - params[f"{prefix}_eLeak"])
        )

    def init_state(
        self,
        states: dict[str, ArrayLike],
        v: ArrayLike,
        params: dict[str, ArrayLike],
        delta_t: float,
    ) -> dict[str, Array]:
        """Initialize each gate at the fixed point of its dynamics.

        For the rates ``alpha`` and ``beta`` at voltage ``v``, the fixed point
        is ``alpha / (alpha + beta)``. ``states``, ``params`` and ``delta_t``
        are not used here.
        """
        prefix = self._name
        alpha_m, beta_m = self.m_gate(v)
        alpha_h, beta_h = self.h_gate(v)
        alpha_n, beta_n = self.n_gate(v)
        return {
            f"{prefix}_m": alpha_m / (alpha_m + beta_m),
            f"{prefix}_h": alpha_h / (alpha_h + beta_h),
            f"{prefix}_n": alpha_n / (alpha_n + beta_n),
        }

    @staticmethod
    def m_gate(v):
        """Return the ``(alpha, beta)`` rates of the sodium gate ``m`` at ``v``."""
        alpha = 0.1 * _vtrap(-(v + 40), 10)
        beta = 4.0 * save_exp(-(v + 65) / 18)
        return alpha, beta

    @staticmethod
    def h_gate(v):
        """Return the ``(alpha, beta)`` rates of the sodium gate ``h`` at ``v``."""
        alpha = 0.07 * save_exp(-(v + 65) / 20)
        beta = 1.0 / (save_exp(-(v + 35) / 10) + 1)
        return alpha, beta

    @staticmethod
    def n_gate(v):
        """Return the ``(alpha, beta)`` rates of the potassium gate ``n`` at ``v``."""
        alpha = 0.01 * _vtrap(-(v + 55), 10)
        beta = 0.125 * save_exp(-(v + 65) / 80)
        return alpha, beta
def _vtrap(x, y):
    """Return ``x / (exp(x / y) - 1)``, finite at its removable singularity ``x == 0``.

    The plain quotient evaluates ``0 / 0 = nan`` when ``x`` is exactly zero.
    For the HH rate calls this happens at v = -40 (``m``) and v = -55 (``n``).
    The limit there is ``y``. When ``|x / y|`` is tiny, the first-order Taylor
    expansion ``y * (1 - x / (2 * y))`` is used instead. This is the same
    threshold and expansion as NEURON's ``vtrap``.
    """
    ratio = x / y
    near_zero = jnp.abs(ratio) < 1e-6
    # Feed a harmless value into the branch that ``jnp.where`` discards. Then
    # neither the forward value nor its gradient picks up a 0/0 NaN.
    safe_x = jnp.where(near_zero, y, x)
    exact = safe_x / (save_exp(safe_x / y) - 1.0)
    # x / (e^r - 1) = y * r / (e^r - 1) ~= y * (1 - r / 2) for small r = x / y.
    taylor = y * (1.0 - ratio / 2.0)
    return jnp.where(near_zero, taylor, exact)