# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
from abc import ABC, abstractmethod
from typing import Dict, Optional, Tuple
from warnings import warn
from jax import Array
from jax.typing import ArrayLike
[docs]
class Channel:
    """Channel base class. All channels inherit from this class.

    A channel in Jaxley is everything that modifies the membrane voltage via its
    current returned by the `compute_current()` method.

    As in NEURON, a `Channel` is considered a distributed process, which means that
    its conductances are to be specified in `S/cm2` and its currents are to be
    specified in `uA/cm2`.
    """

    # Class-level defaults. `_name` is overwritten per instance in `__init__`; the
    # remaining attributes stay `None` unless a subclass (or caller) assigns them.
    _name: Optional[str] = None
    channel_params: Optional[dict] = None
    channel_states: Optional[dict] = None
    current_name: Optional[str] = None

    def __init__(self, name: Optional[str] = None):
        """Create a channel.

        Args:
            name: Name of the channel. If it is `None` or empty, the class name is
                used instead.
        """
        self._name = name if name else self.__class__.__name__

    @property
    def name(self) -> Optional[str]:
        """The name of the channel (by default, this is the class name)."""
        return self._name

    @staticmethod
    def _swap_key_prefix(
        mapping: Optional[dict], old_prefix: str, new_prefix: str
    ) -> Optional[dict]:
        """Return a copy of `mapping` whose keys start with `new_prefix`.

        Only keys that begin with `old_prefix` are rewritten; every other key is
        copied unchanged. `None` is passed through so that channels which never
        populated the mapping can still be renamed.
        """
        if mapping is None:
            return None
        cut = len(old_prefix)
        return {
            (new_prefix + key[cut:] if key.startswith(old_prefix) else key): value
            for key, value in mapping.items()
        }

    def change_name(self, new_name: str) -> "Channel":
        """Change the channel name.

        Keys of `channel_params` and `channel_states` that are prefixed with
        `"<old name>_"` are re-prefixed with `"<new name>_"`. Keys without that
        prefix are kept as they are.

        Args:
            new_name: The new name of the channel.

        Returns:
            Renamed channel, such that this function is chainable.
        """
        old_prefix = self._name + "_"
        new_prefix = new_name + "_"

        self._name = new_name
        self.channel_params = self._swap_key_prefix(
            self.channel_params, old_prefix, new_prefix
        )
        self.channel_states = self._swap_key_prefix(
            self.channel_states, old_prefix, new_prefix
        )
        return self

    def update_states(
        self,
        states: dict[str, Array],
        params: dict[str, Array],
        voltage: Array,
        delta_t: float,
    ) -> dict[str, Array]:
        """Return the updated states.

        Args:
            states: All states of the compartment.
            params: Parameters of the channel (conductances in `S/cm2`).
            voltage: Voltage of the compartment in mV.
            delta_t: The time step in ms.

        Returns:
            A dictionary of updated state values.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def compute_current(
        self,
        states: dict[str, Array],
        params: dict[str, Array],
        voltage: Array,
        delta_t: float,
    ) -> Array:
        """Given channel states and voltage, return the current through the channel.

        Args:
            states: All states of the compartment.
            params: Parameters of the channel (conductances in `S/cm2`).
            voltage: Voltage of the compartment in mV.
            delta_t: The time step in ms.

        Returns:
            Current in `uA/cm2`.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

    def init_state(
        self,
        states: dict[str, Array],
        params: dict[str, Array],
        voltage: Array,
        delta_t: float,
    ) -> dict[str, Array]:
        """Initialize states of channel.

        The base implementation returns an empty dictionary.

        Args:
            states: All states of the compartment.
            params: Parameters of the channel (conductances in `S/cm2`).
            voltage: Voltage of the compartment in mV.
            delta_t: The time step in ms.

        Returns:
            An initial state that is written into ``module.nodes`` when the user
            runs ``module.init_states()``.
        """
        return {}

    def init_params(
        self,
        states: dict[str, Array],
        params: dict[str, Array],
        voltage: Array,
        delta_t: float,
    ) -> dict[str, Array]:
        """Initialize the maximal conductances given the temperature.

        The base implementation returns an empty dictionary.

        Args:
            states: All states of the compartment.
            params: Parameters of the channel (conductances in `S/cm2`).
            voltage: Voltage of the compartment in mV.
            delta_t: The time step in ms.

        Returns:
            Initial parameters that are written into ``module.nodes`` when the user
            runs ``module.init_params()``.
        """
        return {}