jaxley.integrate


integrate(module, params=[], *, param_state=None, data_stimuli=None, data_clamps=None, t_max=None, delta_t=0.025, solver='bwd_euler', voltage_solver='jaxley.dhs', checkpoint_lengths=None, all_states=None, return_states=False)

Solves the ODE system and simulates the neuron model.

Parameters:
  • params (list[dict[str, Array]]) – Trainable parameters returned by get_parameters().

  • param_state (list[dict] | None) – Parameters returned by .data_set().

  • data_stimuli (tuple[Array | ndarray | bool | number | bool | int | float | complex, DataFrame] | None) – Outputs of .data_stimulate(), only needed if stimuli change across function calls.

  • data_clamps (tuple[str, Array | ndarray | bool | number | bool | int | float | complex, DataFrame] | None) – Outputs of .data_clamp(), only needed if clamps change across function calls.

  • t_max (float | None) – Duration of the simulation in milliseconds. If t_max is greater than the length of the stimulus input, the stimulus is padded at the end with zeros. If t_max is smaller, the stimulus is truncated.

  • delta_t (float) – Time step of the solver in milliseconds.

  • solver (str) – Which ODE solver to use. One of [“fwd_euler”, “bwd_euler”, “crank_nicolson”].

  • voltage_solver (str) – Algorithm used to solve the quasi-tridiagonal linear system that describes the voltage equations. The options only take effect when solver is either bwd_euler or crank_nicolson. The options for voltage_solver are jaxley.dhs and jax.sparse. For unbranched cables, jaxley.stone is also supported (it performs well for unbranched cables on GPU).

  • checkpoint_lengths (List[int] | None) – Number of timesteps at every level of checkpointing. The product prod(checkpoint_lengths) must be greater than or equal to the desired number of simulated timesteps. Warning: the simulation is run for prod(checkpoint_lengths) timesteps, and the result is truncated post hoc to the desired simulation length. Therefore, a poor choice of checkpoint_lengths can lead to longer simulation times. If None, no checkpointing is applied.

  • all_states (Dict | None) – An optional initial state that was returned by a previous jx.integrate(…, return_states=True) run. Overrides potentially trainable initial states.

  • return_states (bool) – If True, also return all states, so that the current state of the Module can be set with set_states.

  • module (Module) – The module to simulate, e.g. a Compartment, Branch, Cell, or Network.
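To make the checkpoint_lengths warning concrete, the sketch below only does the arithmetic: how many steps a given choice actually simulates and how many are thrown away. The helpers balanced_checkpoint_lengths and checkpoint_overhead are hypothetical names, not part of jaxley. They assume you already know the number of timesteps (roughly t_max / delta_t; the exact count jaxley derives may differ by one), and using equal lengths on every level is a simple heuristic, not a jaxley recommendation.

```python
import math


def checkpoint_overhead(checkpoint_lengths, num_timesteps):
    """Return (simulated_steps, wasted_steps) for a checkpointing scheme.

    Mirrors the documented rule: the simulation runs for
    prod(checkpoint_lengths) steps and is then truncated to num_timesteps.
    """
    simulated = math.prod(checkpoint_lengths)
    if simulated < num_timesteps:
        raise ValueError(
            f"prod({checkpoint_lengths}) = {simulated} < {num_timesteps} required steps"
        )
    return simulated, simulated - num_timesteps


def balanced_checkpoint_lengths(num_timesteps, levels=2):
    """Smallest equal length per level whose product covers num_timesteps (hypothetical helper)."""
    base = max(1, round(num_timesteps ** (1.0 / levels)))
    while base ** levels < num_timesteps:  # too small: grow
        base += 1
    while base > 1 and (base - 1) ** levels >= num_timesteps:  # too large: shrink
        base -= 1
    return [base] * levels


num_timesteps = 400  # e.g. t_max=10.0 with delta_t=0.025
print(balanced_checkpoint_lengths(num_timesteps, levels=2))  # [20, 20]
print(checkpoint_overhead([20, 20], num_timesteps))  # (400, 0)
print(checkpoint_overhead([8, 8, 8], num_timesteps))  # (512, 112)
print(checkpoint_overhead([30, 30], num_timesteps))  # (900, 500): wasteful
```

A list such as [20, 20] is what you would then pass as jx.integrate(..., checkpoint_lengths=[20, 20]), provided it covers the step count your simulation actually needs.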

Return type:

Array
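The solver parameter picks the time-stepping scheme. As a minimal sketch of how the three options differ, the code below integrates a single passive compartment, dV/dt = (v_inf - V) / tau, with forward Euler, backward Euler and Crank-Nicolson, and compares each result with the analytical solution. This toy model (no channels, no cable coupling) and the names simulate_passive and exact_passive are illustrative assumptions, not jaxley's implementation; in jaxley the implicit schemes additionally solve the coupled voltage system, which is what voltage_solver controls.

```python
import math


def simulate_passive(solver, v0, tau, v_inf, delta_t, t_max):
    """Integrate dV/dt = (v_inf - V) / tau with one of three one-step schemes."""
    a = 1.0 / tau  # decay rate
    b = v_inf / tau  # constant drive, so that v_inf = b / a
    v = v0
    trace = [v]
    for _ in range(round(t_max / delta_t)):
        if solver == "fwd_euler":
            # Explicit: the right-hand side uses only the current value.
            v = v + delta_t * (b - a * v)
        elif solver == "bwd_euler":
            # Implicit: the right-hand side uses the next value, solved in closed form.
            v = (v + delta_t * b) / (1.0 + a * delta_t)
        elif solver == "crank_nicolson":
            # Trapezoidal rule: average of the explicit and implicit right-hand sides.
            v = (v * (1.0 - 0.5 * a * delta_t) + delta_t * b) / (1.0 + 0.5 * a * delta_t)
        else:
            raise ValueError(f"unknown solver: {solver!r}")
        trace.append(v)
    return trace


def exact_passive(v0, tau, v_inf, t):
    """Analytical solution of the same equation, for comparison."""
    return v_inf + (v0 - v_inf) * math.exp(-t / tau)


exact = exact_passive(-70.0, 1.0, -60.0, 10.0)
for delta_t in (0.025, 2.5):
    for solver in ("fwd_euler", "bwd_euler", "crank_nicolson"):
        v_end = simulate_passive(solver, -70.0, 1.0, -60.0, delta_t, 10.0)[-1]
        print(f"delta_t={delta_t:<6} {solver:<15} error={abs(v_end - exact):.2e}")
```

With delta_t = 0.025 ms all three schemes end close to the analytical value. With delta_t = 2.5 ms (larger than 2 * tau), forward Euler's error grows with every step, while the two implicit schemes stay stable, which is one reason the implicit options exist.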

Example usage

The simplest usage is:

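import jaxley as jx
cell = jx.Cell()
cell.record("v")  # record the membrane voltage; integrate() returns the recordings
v = jx.integrate(cell, t_max=10.0)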

If t_max is not passed, then you must have inserted a stimulus, and t_max will match the stimulus length.
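The padding and truncation rule for t_max can be sketched in NumPy. This is not jaxley's internal code: match_stimulus_to_t_max is a hypothetical helper that assumes a stimulus sampled once per delta_t and a step count of round(t_max / delta_t). jaxley's exact sample count may differ by one (for example if the initial time point is included).

```python
import numpy as np


def match_stimulus_to_t_max(stimulus, delta_t, t_max=None):
    """Zero-pad or truncate a 1-D stimulus so that it spans t_max milliseconds.

    If t_max is None, the stimulus length itself defines the duration.
    """
    stimulus = np.asarray(stimulus, dtype=float)
    if t_max is None:
        return stimulus
    num_steps = int(round(t_max / delta_t))
    if num_steps >= stimulus.shape[0]:
        # Pad the end with zeros: no input after the stimulus ends.
        return np.pad(stimulus, (0, num_steps - stimulus.shape[0]))
    # Otherwise drop the samples beyond t_max.
    return stimulus[:num_steps]


delta_t = 0.025
stim = np.full(40, 0.1)  # 1 ms of constant input with amplitude 0.1
padded = match_stimulus_to_t_max(stim, delta_t, t_max=2.0)  # 80 samples, last 40 are zero
truncated = match_stimulus_to_t_max(stim, delta_t, t_max=0.5)  # first 20 samples
print(padded.shape, truncated.shape)
```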

Customizing the solver

If you use jx.integrate(..., voltage_solver="jaxley.dhs"), we automatically choose between a CPU-optimized and a GPU-optimized version. To manually run the CPU-optimized version on GPU, do:

cell._init_morph_jaxley_dhs_solve(allowed_nodes_per_level=1)
v = jx.integrate(cell, voltage_solver="jaxley.dhs.cpu")

To run the GPU-optimized version on CPU, do:

cell._init_morph_jaxley_dhs_solve(allowed_nodes_per_level=16)
v = jx.integrate(cell, voltage_solver="jaxley.dhs.gpu")
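jaxley.dhs, jax.sparse and jaxley.stone are jaxley's own options and none of them is reproduced here. As a rough illustration of the kind of system they solve, the sketch below assembles one backward-Euler step of a hypothetical passive, unbranched cable and solves the resulting tridiagonal system with the classic serial Thomas algorithm. All function names, parameter names and values are illustrative assumptions in arbitrary consistent units; branched morphologies lead to a quasi-tridiagonal system that this sketch does not handle.

```python
import numpy as np


def thomas_solve(lower, diag, upper, rhs):
    """Solve a tridiagonal system with the serial Thomas algorithm.

    Row i reads: lower[i] * x[i-1] + diag[i] * x[i] + upper[i] * x[i+1] = rhs[i],
    with lower[0] and upper[-1] ignored. Assumes no pivoting is needed
    (e.g. a diagonally dominant matrix, as in the cable step below).
    """
    n = len(diag)
    c = np.zeros(n)  # modified super-diagonal
    d = np.zeros(n)  # modified right-hand side
    c[0] = upper[0] / diag[0]
    d[0] = rhs[0] / diag[0]
    # Forward sweep: eliminate the sub-diagonal.
    for i in range(1, n):
        denom = diag[i] - lower[i] * c[i - 1]
        c[i] = upper[i] / denom if i < n - 1 else 0.0
        d[i] = (rhs[i] - lower[i] * d[i - 1]) / denom
    # Back substitution.
    x = np.zeros(n)
    x[-1] = d[-1]
    for i in range(n - 2, -1, -1):
        x[i] = d[i] - c[i] * x[i + 1]
    return x


def passive_cable_bwd_euler_step(v, delta_t, c_m=1.0, g_leak=0.1, e_leak=-70.0, g_axial=0.5):
    """One backward-Euler step of a hypothetical passive, unbranched cable."""
    v = np.asarray(v, dtype=float)
    n = v.shape[0]
    idx = np.arange(n)
    # End compartments have one neighbour, interior ones have two.
    num_neighbours = (idx > 0).astype(float) + (idx < n - 1).astype(float)
    lower = np.where(idx > 0, -g_axial, 0.0)
    upper = np.where(idx < n - 1, -g_axial, 0.0)
    diag = c_m / delta_t + g_leak + g_axial * num_neighbours
    rhs = c_m / delta_t * v + g_leak * e_leak
    return thomas_solve(lower, diag, upper, rhs)


v = np.linspace(-70.0, -50.0, 5)  # initial voltages along the cable
v_next = passive_cable_bwd_euler_step(v, delta_t=0.025)
print(v_next)
```

The Thomas algorithm is inherently sequential along the cable, which is one reason GPU-oriented solvers use different elimination orders.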