# Source code for jaxley.solver_gate
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
import jax.numpy as jnp
from jax import custom_gradient
from jax.typing import ArrayLike
[docs]
def save_exp(x, max_value: float = 20.0):
    """Clip the input to a maximum value and return its exponential.

    Capping the exponent bounds the result by ``exp(max_value)``, which guards
    against overflow to ``inf`` for large inputs.

    Args:
        x: Exponent; scalar or array.
        max_value: Upper bound applied elementwise to ``x`` before
            exponentiating.

    Returns:
        ``exp(min(x, max_value))``, computed elementwise.
    """
    # ``jnp.minimum`` computes exactly what ``jnp.clip(x, a_max=max_value)`` did,
    # but avoids the ``a_max`` keyword: newer JAX releases deprecate it in favour
    # of ``max``, a keyword that older releases do not accept.
    return jnp.exp(jnp.minimum(x, max_value))
[docs]
def solve_gate_implicit(
    gating_state: ArrayLike,
    dt: float,
    alpha: ArrayLike,
    beta: ArrayLike,
):
    """Advance a gating variable by one backward (implicit) Euler step.

    The returned value is the backward-Euler update for the gate equation
    ``dx/dt = alpha * (1 - x) - beta * x``. The right-hand side is evaluated at
    the new state and the resulting linear equation is solved in closed form:
    ``x_new = (x + dt * alpha) / (1 + dt * alpha + dt * beta)``.

    Args:
        gating_state: Current value of the gating variable.
        dt: Time step.
        alpha: Rate of the term driving the gate toward 1.
        beta: Rate of the term driving the gate toward 0.

    Returns:
        The gating variable after one step of size ``dt``.
    """
    alpha_dt = dt * alpha
    return (gating_state + alpha_dt) / (1.0 + alpha_dt + dt * beta)
[docs]
def solve_gate_exponential(
    x: ArrayLike,
    dt: float,
    alpha: ArrayLike,
    beta: ArrayLike,
):
    """Advance a gating variable by one exponential-Euler step.

    The gate equation ``dx/dt = alpha * (1 - x) - beta * x`` is rewritten as a
    relaxation toward a steady state, ``dx/dt = -(x - x_inf) / tau``. Here
    ``tau = 1 / (alpha + beta)`` and ``x_inf = alpha * tau``. That linear
    system is then integrated with ``exponential_euler``.

    There is no guard against ``alpha + beta == 0``, which makes ``tau``
    non-finite.

    Args:
        x: Current value of the gating variable.
        dt: Time step.
        alpha: Rate of the term driving the gate toward 1.
        beta: Rate of the term driving the gate toward 0.

    Returns:
        The gating variable after one step of size ``dt``.
    """
    rate_sum = alpha + beta
    time_constant = 1 / rate_sum
    steady_state = alpha * time_constant
    return exponential_euler(x, dt, steady_state, time_constant)
[docs]
def exponential_euler(
    x: ArrayLike,
    dt: float,
    x_inf: ArrayLike,
    x_tau: ArrayLike,
):
    """An exact solver for the linear dynamical system `dx = -(x - x_inf) / x_tau`.

    Over one step of size ``dt`` the solution is
    ``x_inf + (x - x_inf) * exp(-dt / x_tau)``. It is evaluated as a weighted
    blend of the current state and the steady state.

    The exponent goes through ``save_exp``, which clips it at 20 by default.
    The update is therefore only exact while ``-dt / x_tau <= 20``. That always
    holds for ``dt >= 0`` and positive ``x_tau``.

    Args:
        x: Current state.
        dt: Time step.
        x_inf: Steady-state value the system relaxes toward.
        x_tau: Time constant of the relaxation.

    Returns:
        The state after one step of size ``dt``.
    """
    retained = save_exp(-dt / x_tau)  # weight kept on the current state
    relaxed = 1.0 - retained  # weight moved onto the steady state
    return x * retained + x_inf * relaxed
[docs]
def heaviside(x: ArrayLike, at_zero: ArrayLike = 1.0, grad_scale: float = 10.0):
    """Compute the heaviside step function with a custom derivative.

    Jaxley implementation of ``jax.numpy.heaviside`` with a surrogate derivative.
    The true derivative is zero for every ``x != 0``. It is replaced by
    $\\frac{1}{(g|x| + 1)^2}$, where $g$ is ``grad_scale``. If you experience
    exploding or vanishing derivatives when using this function, try changing
    the value of ``grad_scale``.

    The surrogate derivative is provided with respect to ``x``, which may be a
    scalar or an array. If broadcasting ``x`` against ``at_zero`` produces an
    output larger than ``x``, the incoming cotangent is summed back to the shape
    of ``x``. ``at_zero`` is closed over by the custom-gradient function and the
    custom rule returns no cotangent for it, so it should not be differentiated.

    Args:
        x: Input array or scalar. ``complex`` dtypes are not supported.
        at_zero: Scalar or array. Specifies the return values where ``x`` is
            ``0``. ``complex`` dtypes are not supported. ``x`` and ``at_zero``
            must either have the same shape or be broadcast-compatible.
        grad_scale: Controls the flatness of the surrogate gradient curve.
            Larger values bring it closer to the 'real' gradient, but make the
            function more susceptible to exploding/vanishing gradients.

    Returns:
        An array containing the heaviside step function of ``x``, promoted to
        an inexact dtype.
    """

    @custom_gradient
    def _heaviside_custom(x):
        def _vjp(g):
            # Surrogate derivative of the step, evaluated elementwise at ``x``.
            grad = g / (grad_scale * jnp.abs(x) + 1.0) ** 2
            # ``g`` has the shape of the (possibly broadcast) output. The
            # cotangent for ``x`` must have ``x``'s own shape, so sum over any
            # axes that broadcasting against ``at_zero`` added or expanded.
            x_shape = jnp.shape(x)
            n_leading = grad.ndim - len(x_shape)
            if n_leading > 0:
                grad = jnp.sum(grad, axis=tuple(range(n_leading)))
            expanded = tuple(
                axis
                for axis, size in enumerate(x_shape)
                if size == 1 and grad.shape[axis] != 1
            )
            if expanded:
                grad = jnp.sum(grad, axis=expanded, keepdims=True)
            return (grad,)

        return jnp.heaviside(x, at_zero), _vjp

    return _heaviside_custom(x)