Source code for jaxley.channels.channel

# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
from abc import ABC, abstractmethod
from typing import Dict, Optional, Tuple
from warnings import warn

from jax import Array
from jax.typing import ArrayLike


class Channel:
    """Channel base class. All channels inherit from this class.

    A channel in Jaxley is everything that modifies the membrane voltage via its
    current returned by the `compute_current()` method.

    As in NEURON, a `Channel` is considered a distributed process, which means that
    its conductances are to be specified in `S/cm2` and its currents are to be
    specified in `uA/cm2`.
    """

    # The base class leaves all of these unset (`None`). `change_name()` treats
    # `channel_params` and `channel_states` as dictionaries keyed by name, so
    # subclasses are expected to assign dictionaries to them.
    _name = None
    channel_params = None
    channel_states = None
    current_name = None

    def __init__(self, name: Optional[str] = None):
        """Create the channel.

        Args:
            name: Name of the channel. If it is `None` or empty, the class name
                is used instead.
        """
        self._name = name if name else self.__class__.__name__

    @property
    def name(self) -> Optional[str]:
        """The name of the channel (by default, this is the class name)."""
        return self._name

    @staticmethod
    def _with_renamed_prefix(
        mapping: Optional[dict], old_prefix: str, new_prefix: str
    ) -> Optional[dict]:
        """Return a copy of `mapping` whose `old_prefix` keys use `new_prefix`.

        Keys that do not start with `old_prefix` are kept unchanged. `None` is
        passed through, so that a channel which never defined parameters or
        states can still be renamed.
        """
        if mapping is None:
            return None
        prefix_len = len(old_prefix)
        return {
            (new_prefix + key[prefix_len:] if key.startswith(old_prefix) else key): value
            for key, value in mapping.items()
        }

    def change_name(self, new_name: str) -> "Channel":
        """Change the channel name.

        Every key of `channel_params` and `channel_states` that starts with
        `<old name>_` is rewritten to start with `<new name>_`. All other keys
        are left untouched.

        Args:
            new_name: The new name of the channel.

        Returns:
            Renamed channel, such that this function is chainable.
        """
        old_prefix = self._name + "_"
        new_prefix = new_name + "_"

        self._name = new_name
        self.channel_params = self._with_renamed_prefix(
            self.channel_params, old_prefix, new_prefix
        )
        self.channel_states = self._with_renamed_prefix(
            self.channel_states, old_prefix, new_prefix
        )
        return self

    def update_states(
        self,
        states: dict[str, Array],
        params: dict[str, Array],
        voltage: Array,
        delta_t: float,
    ) -> dict[str, Array]:
        """Return the updated states.

        The base class does not implement this method; it always raises.

        Args:
            states: All states of the compartment.
            params: Parameters of the channel (conductances in `S/cm2`).
            voltage: Voltage of the compartment in mV.
            delta_t: The time step in ms.

        Returns:
            A dictionary of updated state values.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def compute_current(
        self,
        states: dict[str, Array],
        params: dict[str, Array],
        voltage: Array,
        delta_t: float,
    ) -> Array:
        """Given channel states and voltage, return the current through the channel.

        The base class does not implement this method; it always raises.

        Args:
            states: All states of the compartment.
            params: Parameters of the channel (conductances in `S/cm2`).
            voltage: Voltage of the compartment in mV.
            delta_t: The time step in ms.

        Returns:
            Current in `uA/cm2`.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def init_state(
        self,
        states: dict[str, Array],
        params: dict[str, Array],
        voltage: Array,
        delta_t: float,
    ) -> dict[str, Array]:
        """Initialize states of channel.

        The base implementation initializes nothing and returns an empty
        dictionary.

        Args:
            states: All states of the compartment.
            params: Parameters of the channel (conductances in `S/cm2`).
            voltage: Voltage of the compartment in mV.
            delta_t: The time step in ms.

        Returns:
            An initial state that is written into ``module.nodes`` when the user
            runs ``module.init_states()``.
        """
        return {}

    def init_params(
        self,
        states: dict[str, Array],
        params: dict[str, Array],
        voltage: Array,
        delta_t: float,
    ) -> dict[str, Array]:
        """Initialize the maximal conductances given the temperature.

        The base implementation initializes nothing and returns an empty
        dictionary.

        Args:
            states: All states of the compartment.
            params: Parameters of the channel (conductances in `S/cm2`).
            voltage: Voltage of the compartment in mV.
            delta_t: The time step in ms.

        Returns:
            Initial parameters that are written into ``module.nodes`` when the
            user runs ``module.init_params()``.
        """
        return {}