jaxley.integrate.build_init_and_step_fn


build_init_and_step_fn(module, voltage_solver='jaxley.dhs', solver='bwd_euler')

Return init_fn and step_fn, which initialize a module and run update steps, respectively.

This method can be used to gain additional control over the simulation workflow. It exposes the step function, which can be used to perform step-by-step updates of the differential equations.
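Because the integration loop then lives in your own code, you can inspect or act on the state between any two steps. The sketch below is a hypothetical illustration of that idea, not part of jaxley: `run_until` and `toy_step` are made-up names, and `toy_step` takes only the state, whereas jaxley's `step_fn` also takes parameters, externals, external indices and `delta_t`. The driver stops as soon as a user-supplied condition holds.

```python
from typing import Callable, Dict, List

State = Dict[str, float]


def run_until(
    step_fn: Callable[[State], State],
    states: State,
    n_steps: int,
    stop: Callable[[State], bool],
) -> List[State]:
    """Hypothetical driver: apply step_fn repeatedly, stopping early once stop(states) is True."""
    trajectory = [states]
    for _ in range(n_steps):
        states = step_fn(states)
        trajectory.append(states)
        if stop(states):  # The state is available here, between two steps.
            break
    return trajectory


def toy_step(states: State) -> State:
    # Illustrative dynamics only: move v 10% of the way towards -65 each step.
    return {"v": states["v"] + 0.1 * (-65.0 - states["v"])}


traj = run_until(toy_step, {"v": -70.0}, n_steps=1000, stop=lambda s: s["v"] > -66.0)
print(len(traj), traj[-1]["v"])
```

Here the loop ends after 16 of the allowed 1000 steps, because the threshold is crossed early.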

Parameters:
  • module (Module) – A Module object, e.g. a cell.

  • voltage_solver (str) – Voltage solver used in step. Defaults to “jaxley.dhs”.

  • solver (str) – ODE solver. Defaults to “bwd_euler”.
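As a rough, hypothetical illustration of why the time-stepping scheme and the voltage solver are separate arguments: an implicit scheme such as backward Euler turns each step into a linear system for the new voltages, and a separate routine has to solve that system. The sketch below assumes an unbranched, passive cable with sealed ends and solves the tridiagonal system with the classic Thomas algorithm. It is not jaxley's “jaxley.dhs” or “jaxley.stone” implementation; the helper names, units and parameter values are made up for illustration.

```python
from typing import List


def thomas_solve(lower: List[float], diag: List[float], upper: List[float], rhs: List[float]) -> List[float]:
    """Solve a tridiagonal system with the Thomas algorithm.

    lower[i] couples row i to row i-1 (lower[0] is unused);
    upper[i] couples row i to row i+1 (upper[-1] is unused).
    """
    n = len(diag)
    c = [0.0] * n  # Modified upper diagonal.
    d = [0.0] * n  # Modified right-hand side.
    c[0] = upper[0] / diag[0] if n > 1 else 0.0
    d[0] = rhs[0] / diag[0]
    # Forward elimination.
    for i in range(1, n):
        denom = diag[i] - lower[i] * c[i - 1]
        c[i] = upper[i] / denom if i < n - 1 else 0.0
        d[i] = (rhs[i] - lower[i] * d[i - 1]) / denom
    # Back substitution.
    x = [0.0] * n
    x[-1] = d[-1]
    for i in range(n - 2, -1, -1):
        x[i] = d[i] - c[i] * x[i + 1]
    return x


def bwd_euler_cable_step(v, i_ext, dt, c_m=1.0, g_leak=0.1, e_leak=-70.0, g_axial=1.0):
    """One backward-Euler step of a hypothetical passive, unbranched cable.

    Solves c_m * (v_new - v) / dt
        = -g_leak * (v_new - e_leak) + g_axial * sum_nb(v_nb_new - v_new) + i_ext
    for v_new, where sum_nb runs over the neighbouring compartments.
    """
    n = len(v)
    lower, diag, upper, rhs = [0.0] * n, [0.0] * n, [0.0] * n, [0.0] * n
    for i in range(n):
        n_neighbours = (i > 0) + (i < n - 1)  # Sealed ends: end compartments have one neighbour.
        diag[i] = c_m / dt + g_leak + g_axial * n_neighbours
        if i > 0:
            lower[i] = -g_axial
        if i < n - 1:
            upper[i] = -g_axial
        rhs[i] = c_m / dt * v[i] + g_leak * e_leak + i_ext[i]
    return thomas_solve(lower, diag, upper, rhs)


v_next = bwd_euler_cable_step([-70.0] * 4, [1.0, 0.0, 0.0, 0.0], dt=0.025)
print(v_next)
```

Injecting current into the first compartment raises every compartment slightly above rest, with the effect decaying along the cable. Swapping the time-stepping scheme would change how the matrix and right-hand side are assembled, while swapping the linear solver would change only `thomas_solve`.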

Returns:

Functions that initialize the state and parameters, and perform a single integration step, respectively.

Return type:

init_fn, step_fn
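The returned pair follows a common functional pattern: a builder captures the model configuration once and returns two closures, one that creates the initial state and one that advances it by a single step. The sketch below is a hypothetical, simplified version of that pattern for one passive compartment. `build_toy_init_and_step_fn`, the parameter names and the values are illustrative only; jaxley's actual `step_fn` additionally takes `externals` and `external_inds`.

```python
from typing import Callable, Dict, Tuple

State = Dict[str, float]
Params = Dict[str, float]


def build_toy_init_and_step_fn(v_init: float = -70.0) -> Tuple[Callable, Callable]:
    """Hypothetical stand-in for an init_fn/step_fn pair of a one-compartment passive cell."""

    def init_fn(params: Params) -> Tuple[State, Params]:
        # Create the initial state from the captured configuration; pass a copy of params through.
        states = {"v": v_init}
        return states, dict(params)

    def step_fn(states: State, params: Params, i_ext: float, delta_t: float) -> State:
        # Backward Euler for c_m * dv/dt = -g_leak * (v - e_leak) + i_ext.
        # The equation is linear in v_new, so it can be solved in closed form.
        c, g, e = params["c_m"], params["g_leak"], params["e_leak"]
        v_new = (c / delta_t * states["v"] + g * e + i_ext) / (c / delta_t + g)
        return {"v": v_new}

    return init_fn, step_fn


# Usage mirrors the example below: initialize once, then call step_fn in a loop.
init_fn, step_fn = build_toy_init_and_step_fn()
states, params = init_fn({"c_m": 1.0, "g_leak": 0.1, "e_leak": -70.0})
for _ in range(100):
    states = step_fn(states, params, i_ext=0.5, delta_t=0.025)
print(states["v"])
```

With a constant input the voltage relaxes towards e_leak + i_ext / g_leak = -65, and after 100 steps of 0.025 it lies between rest and that steady state. Keeping state outside the closures means every step is a pure function of its inputs, which is what makes such a function easy to pass to `jax.jit`.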

Example usage

The following example performs a step-by-step update of the differential equations and records the requested states at every step.

import jax.numpy as jnp
import jaxley as jx
from jaxley.integrate import build_init_and_step_fn

t_max = 3.0
delta_t = 0.025

cell = jx.Cell()
cell.record()
cell.stimulate(jx.step_current(0, 1, 2, delta_t, t_max))
params = cell.get_parameters()

cell.to_jax()
rec_inds = cell.recordings.rec_index.to_numpy()
rec_states = cell.recordings.state.to_numpy()
externals = cell.externals.copy()
external_inds = cell.external_inds.copy()

# Uncomment this line if `data_stimuli` is not `None`.
# externals, external_inds = add_stimuli(externals, external_inds, data_stimuli)

# Uncomment this line if `data_clamps` is not `None`.
# externals, external_inds = add_clamps(externals, external_inds, data_clamps)

# Initialize.
init_fn, step_fn = build_init_and_step_fn(cell)
states, params = init_fn(params)
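# Initial recording: one value per recorded state, shape (n_recs,), matching
# the shape of the per-step recordings appended in the loop below.
recordings = [
    jnp.asarray(
        [
            states[rec_state][rec_ind]
            for rec_state, rec_ind in zip(rec_states, rec_inds)
        ]
    )
]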

# Integrate the ODE step by step. Wrapping `step_fn` in `jax.jit` can speed up the loop.
steps = int(t_max / delta_t)  # Steps to integrate
for step in range(steps):
    # Get externals at current timestep.
    externals_now = {}
    for key in externals.keys():
        externals_now[key] = externals[key][:, step]

    states = step_fn(
        states, params, externals_now, external_inds, delta_t=delta_t
    )
    recs = jnp.asarray(
        [
            states[rec_state][rec_ind]
            for rec_state, rec_ind in zip(rec_states, rec_inds)
        ]
    )
    recordings.append(recs)

rec = jnp.stack(recordings, axis=0).T
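The layout of `rec` follows from the stacking, provided every entry in `recordings` has the same shape `(n_recs,)`: stacking along axis 0 gives shape `(steps + 1, n_recs)`, and `.T` turns that into one row per recording. The sketch below reproduces only this bookkeeping with NumPy and made-up values (`n_recs`, `steps` and the stand-in recordings are illustrative), and also shows that entries of different shapes cannot be stacked.

```python
import numpy as np

n_recs, steps = 2, 5

# Initial sample: one value per recording, shape (n_recs,).
recordings = [np.zeros(n_recs)]
for step in range(steps):
    recs = np.full(n_recs, float(step + 1))  # Stand-in for the recorded states.
    recordings.append(recs)

rec = np.stack(recordings, axis=0).T
print(rec.shape)  # One row per recording, one column per time point.

# Mixing entries of shape (1,) with entries of shape (n_recs,) cannot be stacked.
raised = False
try:
    np.stack([np.zeros(1), np.zeros(1), np.zeros(n_recs)], axis=0)
except ValueError:
    raised = True
print(raised)
```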