jaxley.integrate.build_init_and_step_fn
- build_init_and_step_fn(module, voltage_solver='jaxley.dhs', solver='bwd_euler')
Return init_fn and step_fn, which initialize modules and run update steps. This method gives additional control over the simulation workflow: it exposes a step function that advances the differential equations by a single time step, so the simulation can be run step by step.
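To make the init/step split concrete, here is a toy sketch of the same pattern in plain JAX. It is a hypothetical analogue, not jaxley's implementation: `build_toy_init_and_step_fn`, the `"v_init"` parameter, and the `"i"` input key are made up for illustration. The model is a set of uncoupled leaky compartments, dv/dt = (-v + i_ext) / tau, advanced with a backward-Euler update (assumed here by analogy with the default `solver='bwd_euler'`).

```python
import jax.numpy as jnp


def build_toy_init_and_step_fn(tau=10.0):
    """Hypothetical analogue of the init/step pattern (not jaxley's implementation).

    Models n uncoupled leaky compartments, dv/dt = (-v + i_ext) / tau.
    """

    def init_fn(params, delta_t=0.025):
        # Collect initial states and parameters into flat dicts.
        # delta_t is unused in this toy; it only mirrors the documented signature.
        all_states = {"v": jnp.asarray(params["v_init"], dtype=jnp.float32)}
        all_params = {"tau": jnp.asarray(tau, dtype=jnp.float32)}
        return all_states, all_params

    def step_fn(all_states, all_params, externals, external_inds, delta_t=0.025):
        v = all_states["v"]
        tau_ = all_params["tau"]
        # Scatter the current-step input values onto their target compartments.
        i_ext = jnp.zeros_like(v).at[external_inds["i"]].add(externals["i"])
        # Backward Euler: v_new = v + dt * (-v_new + i_ext) / tau, solved for v_new.
        v_new = (v + delta_t * i_ext / tau_) / (1.0 + delta_t / tau_)
        return {**all_states, "v": v_new}

    return init_fn, step_fn


init_fn, step_fn = build_toy_init_and_step_fn(tau=10.0)
states, params = init_fn({"v_init": [0.0, 0.0]})
externals_now = {"i": jnp.asarray([1.0])}  # one input value at this step...
external_inds = {"i": jnp.asarray([0])}    # ...applied to compartment 0
for _ in range(400):  # 400 * 0.025 = 10 time units
    states = step_fn(states, params, externals_now, external_inds, delta_t=0.025)
```

Because the builder closes over the model definition, the toy `step_fn` is a pure function of its arguments and can be wrapped in `jax.jit`.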
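- Parameters:
  - module: the jaxley module to simulate (for example, a jx.Cell).
  - voltage_solver (str): name of the solver for the voltage equations; defaults to 'jaxley.dhs'.
  - solver (str): name of the ODE integration scheme; defaults to 'bwd_euler' (backward Euler).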
- Returns:
init_fn(params, all_states=None, param_state=None, delta_t=0.025): Callable which initializes the states and parameters.
Args:
  - params (list[dict]): returned by .get_parameters().
  - all_states (dict | None = None): typically None.
  - param_state (list[dict] | None = None): returned by .data_set().
  - delta_t (float = 0.025): the time step.
Returns:
all_states (dict) and all_params (dict), which can be passed to step_fn.
step_fn(all_states, all_params, externals, external_inds, delta_t=0.025): Callable which performs a single integration step.
Args:
  - all_states (dict): returned by init_fn().
  - all_params (dict): returned by init_fn().
  - externals (dict): obtained with module.externals.copy(), but containing only the external input at the current time step (see the example below).
  - external_inds (dict): obtained with module.external_inds.copy().
  - delta_t (float): the time step.
Returns:
Updated all_states (dict).
- Return type:
tuple[Callable, Callable], i.e. (init_fn, step_fn).
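step_fn expects externals restricted to a single time step. A minimal helper for that slicing is sketched below, assuming (as the `[:, step]` indexing in the example usage implies) that each entry of `module.externals` has shape `(num_inputs, num_time_steps)`. The name `externals_at_step` and the toy data are hypothetical.

```python
import jax.numpy as jnp


def externals_at_step(externals, step):
    """Return a dict holding only column `step` of each external input.

    Assumes each value has shape (num_inputs, num_time_steps).
    """
    return {key: value[:, step] for key, value in externals.items()}


# Hypothetical externals: two current inputs over five time steps.
externals = {"i": jnp.arange(10.0).reshape(2, 5)}
now = externals_at_step(externals, 3)
# now["i"] holds the values of both inputs at time step 3: [3.0, 8.0]
```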
Example usage
The following allows you to perform a step-by-step update of the differential equations.
import jax.numpy as jnp
import jaxley as jx
from jaxley.integrate import build_init_and_step_fn

t_max = 3.0
delta_t = 0.025

cell = jx.Cell()
cell.record()
cell.stimulate(jx.step_current(0, 1, 2, delta_t, t_max))

params = cell.get_parameters()
cell.to_jax()

rec_inds = cell.recordings.rec_index.to_numpy()
rec_states = cell.recordings.state.to_numpy()

externals = cell.externals.copy()
external_inds = cell.external_inds.copy()
# Uncomment this line if `data_stimuli` is not `None`.
# externals, external_inds = add_stimuli(externals, external_inds, data_stimuli)
# Uncomment this line if `data_clamps` is not `None`.
# externals, external_inds = add_clamps(externals, external_inds, data_clamps)

# Initialize.
init_fn, step_fn = build_init_and_step_fn(cell)
states, params = init_fn(params)
# Record the initial states.
recordings = [
    jnp.asarray(
        [states[rec_state][rec_ind] for rec_state, rec_ind in zip(rec_states, rec_inds)]
    )
]

# Loop over the ODE. The `step_fn` can be jitted to improve speed.
steps = int(t_max / delta_t)  # Number of steps to integrate.
for step in range(steps):
    # Get the externals at the current time step.
    externals_now = {}
    for key in externals.keys():
        externals_now[key] = externals[key][:, step]
    states = step_fn(states, params, externals_now, external_inds, delta_t=delta_t)
    recs = jnp.asarray(
        [states[rec_state][rec_ind] for rec_state, rec_ind in zip(rec_states, rec_inds)]
    )
    recordings.append(recs)

# `rec` has shape (num_recordings, steps + 1).
rec = jnp.stack(recordings, axis=0).T
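The loop above calls step_fn once per time step from Python. One common JAX idiom for the same loop is `jax.lax.scan`, which traces the step once and can be compiled with `jax.jit`. The sketch below uses a hypothetical `toy_step_fn` with the same argument order as the call in the example (`states, params, externals_now, external_inds, delta_t`) so that it runs without jaxley. Swapping in the real step_fn, and recording the states you need instead of only `"v"`, is an assumption to check against your jaxley version.

```python
import jax
import jax.numpy as jnp


def toy_step_fn(all_states, all_params, externals, external_inds, delta_t=0.025):
    # Hypothetical stand-in for the step_fn returned by build_init_and_step_fn:
    # backward-Euler update of dv/dt = (-v + i_ext) / tau.
    v, tau = all_states["v"], all_params["tau"]
    i_ext = jnp.zeros_like(v).at[external_inds["i"]].add(externals["i"])
    return {**all_states, "v": (v + delta_t * i_ext / tau) / (1.0 + delta_t / tau)}


def simulate(step_fn, states, params, externals, external_inds, delta_t):
    """Run step_fn over every time step with jax.lax.scan instead of a Python loop."""

    def body(carry, externals_now):
        carry = step_fn(carry, params, externals_now, external_inds, delta_t)
        return carry, carry["v"]  # record v after each step

    # scan iterates over the leading axis, so move time to axis 0.
    xs = {key: value.T for key, value in externals.items()}
    _, v_trace = jax.lax.scan(body, states, xs)
    return v_trace  # shape (num_time_steps, num_compartments)


# step_fn is a Python callable, so it is marked static for jit.
simulate_jit = jax.jit(simulate, static_argnums=(0,))

states = {"v": jnp.zeros(2)}
params = {"tau": jnp.asarray(10.0)}
externals = {"i": jnp.ones((1, 400))}  # one input, 400 time steps
external_inds = {"i": jnp.asarray([0])}  # applied to compartment 0
v_trace = simulate_jit(toy_step_fn, states, params, externals, external_inds, 0.025)
```

Unlike the example above, this sketch records only the states after each step, not the initial state, so `v_trace` has one row per step.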