jaxley.integrate.build_init_and_step_fn
- build_init_and_step_fn(module, voltage_solver='jaxley.dhs', solver='bwd_euler')
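Return `init_fn` and `step_fn`, which initialize a module and run single update steps.

This function gives additional control over the simulation workflow: it exposes the step function, which can be used to update the differential equations one time step at a time.

- Parameters:
  - module: the module to simulate, e.g. a `jx.Cell`.
  - voltage_solver: solver for the voltage equations. Defaults to `'jaxley.dhs'`.
  - solver: ODE solver used for each integration step. Defaults to `'bwd_euler'` (backward Euler).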
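- Returns:
  - `init_fn`, `step_fn`: functions that initialize the state and parameters, and perform a single integration step, respectively. In the example below, `init_fn(params)` returns `(states, params)`, and `step_fn(states, params, externals_now, external_inds, delta_t=delta_t)` returns the updated `states`.
- Return type:
  - tuple of two callables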
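To make the `init_fn`/`step_fn` contract concrete, here is a minimal sketch of such a pair for a single passive compartment, C dv/dt = -g_leak (v - e_leak) + i_ext, in arbitrary units. It is not jaxley's implementation: the builder name `build_toy_init_and_step_fn`, the parameter keys, and the model itself are assumptions made for illustration. It does show what a backward-Euler step (the idea behind `solver='bwd_euler'`) does: the new voltage appears on both sides of the update and is solved for, which for this linear equation reduces to one division.

```python
import jax
import jax.numpy as jnp


def build_toy_init_and_step_fn(v_init=-70.0):
    """Hypothetical init/step pair for one passive compartment (not jaxley's code).

    Model: C dv/dt = -g_leak * (v - e_leak) + i_ext, in arbitrary units.
    """

    def init_fn(params):
        # States are a dict of arrays keyed by state name, one entry per compartment.
        states = {"v": jnp.full((1,), v_init, dtype=jnp.float32)}
        return states, params

    def step_fn(states, params, externals, delta_t=0.025):
        v = states["v"]
        c, g, e = params["capacitance"], params["g_leak"], params["e_leak"]
        # External current at this time step; zero if no stimulus is given.
        i_ext = externals.get("i", jnp.zeros_like(v))
        # Backward Euler: v_new = v + dt / C * (-g * (v_new - e) + i_ext).
        # The equation is linear in v_new, so solving it is a single division.
        v_new = (v + delta_t / c * (g * e + i_ext)) / (1.0 + delta_t * g / c)
        return {**states, "v": v_new}

    return init_fn, step_fn


params = {"capacitance": 1.0, "g_leak": 0.1, "e_leak": -65.0}
init_fn, step_fn = build_toy_init_and_step_fn()
states, params = init_fn(params)

step = jax.jit(step_fn)  # compile once, reuse for every time step
trace = [states["v"]]
for _ in range(400):
    states = step(states, params, {}, delta_t=0.025)
    trace.append(states["v"])
trace = jnp.stack(trace, axis=0).T  # shape (1, 401): (compartments, time points)
```

Because the update is implicit, it stays bounded and relaxes toward `e_leak` even for a very large `delta_t`, which is a common reason to prefer backward Euler for stiff membrane equations.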
Example usage
The following example integrates the differential equations of a single cell step by step and collects the recorded states after every step.
```python
import jax.numpy as jnp

import jaxley as jx
from jaxley.integrate import build_init_and_step_fn

t_max = 3.0
delta_t = 0.025

cell = jx.Cell()
cell.record()
cell.stimulate(jx.step_current(0, 1, 2, delta_t, t_max))

params = cell.get_parameters()
cell.to_jax()

rec_inds = cell.recordings.rec_index.to_numpy()
rec_states = cell.recordings.state.to_numpy()
externals = cell.externals.copy()
external_inds = cell.external_inds.copy()

# Uncomment this line if `data_stimuli` is not `None`.
# externals, external_inds = add_stimuli(externals, external_inds, data_stimuli)
# Uncomment this line if `data_clamps` is not `None`.
# externals, external_inds = add_clamps(externals, external_inds, data_clamps)

# Initialize.
init_fn, step_fn = build_init_and_step_fn(cell)
states, params = init_fn(params)

# The first time point holds the initial value of every recorded state, with the
# same shape (number of recordings,) as the per-step entries appended below.
recordings = [
    jnp.asarray(
        [states[rec_state][rec_ind] for rec_state, rec_ind in zip(rec_states, rec_inds)]
    )
]

# Loop over the ODE. The `step_fn` can be jitted to improve speed.
steps = int(t_max / delta_t)  # Steps to integrate
for step in range(steps):
    # Get externals at current timestep.
    externals_now = {}
    for key in externals.keys():
        externals_now[key] = externals[key][:, step]

    states = step_fn(
        states, params, externals_now, external_inds, delta_t=delta_t
    )
    recs = jnp.asarray(
        [states[rec_state][rec_ind] for rec_state, rec_ind in zip(rec_states, rec_inds)]
    )
    recordings.append(recs)

# Shape: (number of recordings, steps + 1).
rec = jnp.stack(recordings, axis=0).T
```
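The loop in the example dispatches one `step_fn` call per time step. One further option, sketched here under assumptions, is to express the time loop with `jax.lax.scan`, so that it is traced and compiled once as a whole. The sketch uses a hypothetical single-compartment leak model (`toy_step`, `simulate`, and the parameter keys are illustrative names, not jaxley API) but keeps the conventions of the example: externals shaped `(n_stimuli, n_timesteps)` consumed one column per step, and recordings returned as `(n_recordings, n_timesteps + 1)` with the initial state first. Whether jaxley's own `step_fn` can be placed in the scan body unchanged is not covered here.

```python
import jax
import jax.numpy as jnp


def toy_step(states, params, externals_now, delta_t):
    """Hypothetical backward-Euler step for C dv/dt = -g (v - e) + i (not jaxley's)."""
    v = states["v"]
    c, g, e = params["c"], params["g"], params["e"]
    v_new = (v + delta_t / c * (g * e + externals_now["i"])) / (1.0 + delta_t * g / c)
    return {**states, "v": v_new}


def simulate(states, params, externals, delta_t):
    """Run `toy_step` once per time column of `externals` using jax.lax.scan."""
    # Externals are (n_stimuli, n_timesteps); scan loops over the leading axis,
    # so transpose each array to (n_timesteps, n_stimuli).
    xs = {key: value.T for key, value in externals.items()}

    def body(carry, externals_now):
        new_states = toy_step(carry, params, externals_now, delta_t)
        return new_states, new_states["v"]  # carry the state, record v

    final_states, v_trace = jax.lax.scan(body, states, xs)
    # Prepend the initial value and transpose to (n_recordings, n_timesteps + 1).
    rec = jnp.concatenate([states["v"][None], v_trace], axis=0).T
    return final_states, rec


t_max, delta_t = 3.0, 0.025
steps = round(t_max / delta_t)
t = jnp.arange(steps) * delta_t
# One stimulus of amplitude 2 that is on while t < 1.0.
i_ext = jnp.where(t < 1.0, 2.0, 0.0).astype(jnp.float32)[None, :]  # (1, steps)

params = {"c": 1.0, "g": 0.1, "e": -65.0}
init_states = {"v": jnp.full((1,), -70.0, dtype=jnp.float32)}
final_states, rec = jax.jit(simulate)(init_states, params, {"i": i_ext}, delta_t)
```

Returning `new_states["v"]` as the per-step output plays the role of `recordings.append(recs)` in the loop; `lax.scan` stacks those outputs along a new leading time axis.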