jaxley.integrate.build_init_and_step_fn

build_init_and_step_fn(module, voltage_solver='jaxley.dhs', solver='bwd_euler')

Return init_fn and step_fn, which initialize a module and run single update steps, respectively.

Use this function when you need finer control over the simulation workflow than a single call to jx.integrate offers. The returned step_fn advances the differential equations by one time step per call, so the simulation can be run and inspected step by step.
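As a rough illustration of that extra control, the sketch below replaces the two returned callables with hypothetical toy versions (a single passive compartment, not Jaxley's solver) and stops integrating as soon as the voltage crosses a threshold, something a fixed-length simulation call does not offer. The names init_fn, step_fn, E_LEAK and the dynamics are assumptions made for this example only.

```python
# Hypothetical stand-ins for the callables returned by build_init_and_step_fn,
# for one passive compartment: dv/dt = (E_LEAK - v) / tau + i_ext.
# Names and dynamics are illustrative assumptions, not Jaxley internals.
E_LEAK = -70.0


def init_fn(params, delta_t=0.025):
    states = {"v": E_LEAK}
    return states, dict(params)


def step_fn(states, params, i_ext, delta_t=0.025):
    v = states["v"]
    tau = params["tau"]
    # Backward Euler update, solved in closed form for this linear ODE.
    v_new = (v + delta_t * (E_LEAK / tau + i_ext)) / (1.0 + delta_t / tau)
    return {"v": v_new}


states, params = init_fn({"tau": 10.0})
threshold, delta_t = -60.0, 0.025
trace = [states["v"]]
for step in range(10_000):
    i_ext = 2.0  # could instead be computed from `states` (closed-loop input)
    states = step_fn(states, params, i_ext, delta_t=delta_t)
    trace.append(states["v"])
    if states["v"] > threshold:  # stop as soon as the threshold is crossed
        break
```

Because the loop is ordinary Python, any per-step logic (early stopping, state-dependent input, logging) can sit between calls to step_fn.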

Parameters:
  • module (Module) – A jx.Module object, for example a jx.Cell.

  • voltage_solver (str) – Voltage solver used in step. Defaults to “jaxley.dhs”.

  • solver (str) – ODE solver. One of { “bwd_euler” | “crank_nicolson” | “fwd_euler” | “exp_euler” }. Defaults to “bwd_euler”.
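To make the solver names concrete, the following sketch applies each scheme to a scalar passive membrane, dv/dt = (E - v) / tau. This is a toy under stated assumptions: Jaxley applies these schemes to the full multi-compartment system, and its implementation may differ in detail. The point is only the update rule that each name denotes.

```python
import math

# Toy passive membrane dv/dt = (E - v) / tau. Illustrates what each solver
# name means for a linear scalar ODE; NOT Jaxley's implementation.


def fwd_euler(v, dt, E, tau):
    # Explicit: evaluate the derivative at the current state.
    return v + dt * (E - v) / tau


def bwd_euler(v, dt, E, tau):
    # Implicit: solve v_new = v + dt * (E - v_new) / tau for v_new.
    return (v + dt * E / tau) / (1.0 + dt / tau)


def crank_nicolson(v, dt, E, tau):
    # Trapezoidal rule: average of the explicit and implicit derivatives.
    h = dt / (2.0 * tau)
    return (v * (1.0 - h) + 2.0 * h * E) / (1.0 + h)


def exp_euler(v, dt, E, tau):
    # Exponential integrator: exact for a linear ODE with constant coefficients.
    return E + (v - E) * math.exp(-dt / tau)


def integrate(update, v0, dt, t_max, E=-70.0, tau=10.0):
    v = v0
    for _ in range(round(t_max / dt)):
        v = update(v, dt, E, tau)
    return v


# Accuracy at a small step (dt = 0.025 ms, t = 10 ms) against the exact solution.
exact = -70.0 + (-65.0 + 70.0) * math.exp(-10.0 / 10.0)
results = {
    f.__name__: integrate(f, -65.0, 0.025, 10.0)
    for f in (fwd_euler, bwd_euler, crank_nicolson, exp_euler)
}

# Stability at a large step (dt = 25 ms > 2 * tau).
unstable = integrate(fwd_euler, -65.0, 25.0, 500.0)
stable = integrate(bwd_euler, -65.0, 25.0, 500.0)
```

With the large step, the explicit fwd_euler update oscillates and diverges while bwd_euler decays to the resting value; unconditional stability on stiff problems is the usual argument for an implicit default.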

Returns:

  • init_fn(params, all_states=None, param_state=None, delta_t=0.025)

    Callable which initializes the states and parameters.

    • Args:

      • params (list[dict]): returned by .get_parameters().

      • all_states (dict | None = None): typically None.

      • param_state (list[dict] | None = None): returned by .data_set().

      • delta_t (float = 0.025): the time step.

    • Returns:

      • all_states (dict).

      • all_params (dict), which can be passed to the step_fn.

  • step_fn(all_states, all_params, externals, external_inds, delta_t=0.025)

    Callable which performs a single integration step.

    • Args:

      • all_states (dict): returned by init_fn().

      • all_params (dict): returned by init_fn().

      • externals (dict): obtained with module.externals.copy(), but restricted to the external input at the current time step (see the example below).

      • external_inds (dict): obtained with module.external_inds.copy().

      • delta_t (float = 0.025): the time step.

    • Returns:

      • Updated all_states (dict).

Return type:

Tuple[Callable, Callable]
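step_fn expects externals for the current time step only. A minimal sketch of that slicing, assuming (as the loop in the example below indexes them) that each entry of externals has the time axis last, can use jax.tree_util.tree_map instead of an explicit dict loop. The array values and the helper name externals_at are made up for illustration.

```python
import jax
import jax.numpy as jnp

# Made-up externals: key "i" holds 2 stimuli over 6 time steps, with the time
# axis last, matching how the example indexes externals[key][:, step].
externals = {"i": jnp.arange(12.0).reshape(2, 6)}


def externals_at(externals, step):
    # Take column `step` of every array in the dict; equivalent to building
    # externals_now[key] = externals[key][:, step] in a Python loop.
    return jax.tree_util.tree_map(lambda x: x[:, step], externals)


externals_now = externals_at(externals, 3)
# A traced step index also works, e.g. inside jax.jit or jax.lax.scan.
jitted_now = jax.jit(externals_at)(externals, 3)
```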

Example usage

The following example integrates the differential equations one step at a time and records the requested states after every step.

import jax.numpy as jnp
import jaxley as jx
from jaxley.integrate import build_init_and_step_fn

t_max = 3.0
delta_t = 0.025

cell = jx.Cell()
cell.record()
cell.stimulate(jx.step_current(0, 1, 2, delta_t, t_max))
params = cell.get_parameters()

cell.to_jax()
rec_inds = cell.recordings.rec_index.to_numpy()
rec_states = cell.recordings.state.to_numpy()
externals = cell.externals.copy()
external_inds = cell.external_inds.copy()

# Uncomment this line if `data_stimuli` is not `None`.
# externals, external_inds = add_stimuli(externals, external_inds, data_stimuli)

# Uncomment this line if `data_clamps` is not `None`.
# externals, external_inds = add_clamps(externals, external_inds, data_clamps)

# Initialize.
init_fn, step_fn = build_init_and_step_fn(cell)
states, params = init_fn(params)
# Record the initial value of every recorded state.
recordings = [jnp.asarray(
    [
        states[rec_state][rec_ind]
        for rec_state, rec_ind in zip(rec_states, rec_inds)
    ]
)]

# Integrate the ODE step by step. Wrapping `step_fn` in `jax.jit` can improve speed.
steps = int(t_max / delta_t)  # Steps to integrate
for step in range(steps):
    # Get externals at current timestep.
    externals_now = {}
    for key in externals.keys():
        externals_now[key] = externals[key][:, step]

    states = step_fn(
        states, params, externals_now, external_inds, delta_t=delta_t
    )
    recs = jnp.asarray(
        [
            states[rec_state][rec_ind]
            for rec_state, rec_ind in zip(rec_states, rec_inds)
        ]
    )
    recordings.append(recs)

# Shape: (num_recordings, steps + 1).
rec = jnp.stack(recordings, axis=0).T
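The Python loop above dispatches step_fn once per step. One common way to act on the comment about jitting is to compile the whole loop with jax.lax.scan. The sketch below uses a hypothetical toy step_fn with the same call pattern (one passive compartment with a backward Euler update, not Jaxley's solver) so that it runs on its own; that the step_fn returned by build_init_and_step_fn drops into the same pattern is an assumption of this sketch.

```python
import jax
import jax.numpy as jnp


# Hypothetical toy step_fn with the same call pattern as in the example:
# one passive compartment, dv/dt = (e_leak - v) / tau + i_ext, advanced with a
# closed-form backward Euler update. Not Jaxley's solver.
def step_fn(states, params, externals, external_inds, delta_t=0.025):
    v = states["v"]
    # Scatter the current-step inputs onto the stimulated compartments.
    i_ext = jnp.zeros_like(v).at[external_inds["i"]].add(externals["i"])
    tau = params["tau"]
    v_new = (v + delta_t * (params["e_leak"] / tau + i_ext)) / (1.0 + delta_t / tau)
    return {"v": v_new}


@jax.jit
def simulate(states, params, externals, external_inds, delta_t):
    def body(carry, externals_now):
        carry = step_fn(carry, params, externals_now, external_inds, delta_t)
        return carry, carry["v"]  # (new carry, per-step recording)

    # lax.scan iterates over the leading axis, so move time to the front.
    xs = jax.tree_util.tree_map(lambda x: jnp.moveaxis(x, -1, 0), externals)
    _, recs = jax.lax.scan(body, states, xs)
    # Prepend the initial state, then transpose to (num_recordings, steps + 1).
    return jnp.concatenate([states["v"][None], recs], axis=0).T


t_max, delta_t = 3.0, 0.025
steps = round(t_max / delta_t)
t = jnp.arange(steps) * delta_t
externals = {"i": jnp.where((t >= 1.0) & (t < 2.0), 0.5, 0.0)[None, :]}
external_inds = {"i": jnp.array([0])}
states = {"v": jnp.full((1,), -70.0)}
params = {"e_leak": -70.0, "tau": 10.0}

rec = simulate(states, params, externals, external_inds, delta_t)
```

lax.scan also stacks the per-step outputs, so the recordings come back as one array with the same (num_recordings, steps + 1) layout as rec in the example, without a Python-level append per step.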