Source code for jaxley.integrate

# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>

from math import prod
from typing import Callable, Dict, List, Optional, Tuple, Union

import jax
import jax.numpy as jnp
import pandas as pd

from jaxley.modules import Module
from jaxley.utils.cell_utils import params_to_pstate
from jaxley.utils.jax_utils import nested_checkpoint_scan


[docs] def build_init_and_step_fn( module: Module, voltage_solver: str = "jaxley.dhs", solver: str = "bwd_euler", ) -> Tuple[Callable, Callable]: """Return ``init_fn`` and ``step_fn`` which initialize modules and run update steps. This method can be used to gain additional control over the simulation workflow. It exposes the ``step`` function, which can be used to perform step-by-step updates of the differential equations. Args: module: A `Module` object that e.g. a cell. voltage_solver: Voltage solver used in step. Defaults to "jaxley.stone". solver: ODE solver. Defaults to "bwd_euler". Returns: init_fn, step_fn: Functions that initialize the state and parameters, and perform a single integration step, respectively. Example usage ^^^^^^^^^^^^^ The following allows you to perform a step-by-step update of the differential equations. :: import jax.numpy as jnp import jaxley as jx from jaxley.integrate import build_init_and_step_fn t_max = 3.0 delta_t = 0.025 cell = jx.Cell() cell.record() cell.stimulate(jx.step_current(0, 1, 2, delta_t, t_max)) params = cell.get_parameters() cell.to_jax() rec_inds = cell.recordings.rec_index.to_numpy() rec_states = cell.recordings.state.to_numpy() externals = cell.externals.copy() external_inds = cell.external_inds.copy() # Uncomment this line if `data_stimuli` is not `None`. # externals, external_inds = add_stimuli(externals, external_inds, data_stimuli) # Uncomment this line if `data_clamps` is not `None`. # externals, external_inds = add_clamps(externals, external_inds, data_clamps) # Initialize. init_fn, step_fn = build_init_and_step_fn(cell) states, params = init_fn(params) recordings = [ states[rec_state][rec_ind][None] for rec_state, rec_ind in zip(rec_states, rec_inds) ] # Loop over the ODE. The `step_fn` can be jitted for improving speed. steps = int(t_max / delta_t) # Steps to integrate for step in range(steps): # Get externals at current timestep. externals_now = {} for key in externals.keys(): externals_now[key] = externals[key][:, step] states = step_fn( states, params, externals_now, external_inds, delta_t=delta_t ) recs = jnp.asarray( [ states[rec_state][rec_ind] for rec_state, rec_ind in zip(rec_states, rec_inds) ] ) recordings.append(recs) rec = jnp.stack(recordings, axis=0).T """ # Initialize the external inputs and their indices. external_inds = module.external_inds.copy() def init_fn( params: List[Dict[str, jnp.ndarray]], all_states: Optional[Dict] = None, param_state: Optional[List[Dict]] = None, delta_t: float = 0.025, ) -> Tuple[Dict, Dict]: """Initializes the parameters and states of the neuron model. Args: params: List of trainable parameters. all_states: State if alread initialized. Defaults to None. param_state: Parameters returned by `data_set`.. Defaults to None. delta_t: Step size. Defaults to 0.025. Returns: All states and parameters. """ # Make the `trainable_params` of the same shape as the `param_state`, such that # they can be processed together by `get_all_parameters`. pstate = params_to_pstate(params, module.indices_set_by_trainables) if param_state is not None: pstate += param_state all_params = module.get_all_parameters(pstate) all_states = ( module.get_all_states(pstate, all_params, delta_t) if all_states is None else all_states ) return all_states, all_params def step_fn( all_states: Dict, all_params: Dict, externals: Dict, external_inds: Dict = external_inds, delta_t: float = 0.025, ) -> Dict: """Performs a single integration step with step size delta_t. Args: all_states: Current state of the neuron model. all_params: Current parameters of the neuron model. externals: External inputs. external_inds: External indices. Defaults to `module.external_inds`. delta_t: Time step. Defaults to 0.025. Returns: Updated states. """ state = all_states state = module.step( state, delta_t, external_inds, externals, params=all_params, solver=solver, voltage_solver=voltage_solver, ) return state return init_fn, step_fn
def add_stimuli( externals: Dict, external_inds: Dict, data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None, ) -> Tuple[Dict, Dict]: """Extends the external inputs with the stimuli. Args: externals: Current external inputs. external_inds: Current external indices. data_stimuli: Additional data stimuli. Defaults to None. Returns: Updated external inputs and indices. """ # If stimulus is inserted, add it to the external inputs. if "i" in externals.keys() or data_stimuli is not None: if "i" in externals.keys(): if data_stimuli is not None: externals["i"] = jnp.concatenate([externals["i"], data_stimuli[1]]) external_inds["i"] = jnp.concatenate( [external_inds["i"], data_stimuli[2].index.to_numpy()] ) else: externals["i"] = data_stimuli[1] external_inds["i"] = data_stimuli[2].index.to_numpy() return externals, external_inds def add_clamps( externals: Dict, external_inds: Dict, data_clamps: Optional[Tuple[str, jnp.ndarray, pd.DataFrame]] = None, ) -> Tuple[Dict, Dict]: """Adds clamps to the external inputs. Args: externals: Current external inputs. external_inds: Current external indices. data_clamps: Additional data clamps. Defaults to None. Returns: Updated external inputs and indices. """ # If a clamp is inserted, add it to the external inputs. if data_clamps is not None: state_name, clamps, inds = data_clamps if state_name in externals.keys(): externals[state_name] = jnp.concatenate([externals[state_name], clamps]) external_inds[state_name] = jnp.concatenate( [external_inds[state_name], inds.index.to_numpy()] ) else: externals[state_name] = clamps external_inds[state_name] = inds.index.to_numpy() return externals, external_inds
[docs] def integrate( module: Module, params: List[Dict[str, jnp.ndarray]] = [], *, param_state: Optional[List[Dict]] = None, data_stimuli: Optional[Tuple[jnp.ndarray, pd.DataFrame]] = None, data_clamps: Optional[Tuple[str, jnp.ndarray, pd.DataFrame]] = None, t_max: Optional[float] = None, delta_t: float = 0.025, solver: str = "bwd_euler", voltage_solver: str = "jaxley.dhs", checkpoint_lengths: Optional[List[int]] = None, all_states: Optional[Dict] = None, return_states: bool = False, ) -> jnp.ndarray: """ Solves ODE and simulates neuron model. Args: params: Trainable parameters returned by `get_parameters()`. param_state: Parameters returned by `data_set`. data_stimuli: Outputs of `.data_stimulate()`, only needed if stimuli change across function calls. data_clamps: Outputs of `.data_clamp()`, only needed if clamps change across function calls. t_max: Duration of the simulation in milliseconds. If `t_max` is greater than the length of the stimulus input, the stimulus will be padded at the end with zeros. If `t_max` is smaller, then the stimulus with be truncated. delta_t: Time step of the solver in milliseconds. solver: Which ODE solver to use. Either of ["fwd_euler", "bwd_euler", "crank_nicolson"]. voltage_solver: Algorithm to solve quasi-tridiagonal linear system describing the voltage equations. The different options only take effect when `solver` is either `bwd_euler` or `crank_nicolson`. The options for `voltage_solver` are `jaxley.dhs` and `jax.sparse`. For unbranched cables, we also support `jaxley.stone` (which has good performance for unbranched cables on GPU). checkpoint_lengths: Number of timesteps at every level of checkpointing. The `prod(checkpoint_lengths)` must be larger or equal to the desired number of simulated timesteps. Warning: the simulation is run for `prod(checkpoint_lengths)` timesteps, and the result is posthoc truncated to the desired simulation length. Therefore, a poor choice of `checkpoint_lengths` can lead to longer simulation time. If `None`, no checkpointing is applied. all_states: An optional initial state that was returned by a previous `jx.integrate(..., return_states=True)` run. Overrides potentially trainable initial states. return_states: If True, it returns all states such that the current state of the `Module` can be set with `set_states`. Example usage ^^^^^^^^^^^^^ The most simple usage is the following: :: cell = jx.Cell() v = jx.integrate(cell, t_max=10.0) If ``t_max`` is not passed, then you must have inserted a stimulus, and ``t_max`` will match the stimulus length. Customizing the solver ^^^^^^^^^^^^^^^^^^^^^^ If you use ``jx.integrate(..., voltage_solver="jaxley.dhs")``, we automatically choose between a CPU and a GPU optimized version. If you manually want to run the CPU-optimized version on GPU, do: :: cell._init_morph_jaxley_dhs_solve(allowed_nodes_per_level=1) v = jx.integrate(cell, voltage_solver="jaxley.dhs.cpu") To run the GPU-opotimized version on CPU, do: :: cell._init_morph_jaxley_dhs_solve(allowed_nodes_per_level=16) v = jx.integrate(cell, voltage_solver="jaxley.dhs.gpu") """ if voltage_solver == "jaxley.dhs": # Automatically infer the voltage solver. if module._solver_device in ["gpu", "tpu"]: voltage_solver = "jaxley.dhs.gpu" else: voltage_solver = "jaxley.dhs.cpu" assert module.initialized, "Module is not initialized, run `._initialize()`." module.to_jax() # Creates `.jaxnodes` from `.nodes` and `.jaxedges` from `.edges`. # Initialize the external inputs and their indices. externals = module.externals.copy() external_inds = module.external_inds.copy() # If stimulus is inserted, add it to the external inputs. externals, external_inds = add_stimuli(externals, external_inds, data_stimuli) # If a clamp is inserted, add it to the external inputs. externals, external_inds = add_clamps(externals, external_inds, data_clamps) if not externals.keys(): # No stimulus was inserted and no clamp was set. assert t_max is not None, ( "If no stimulus or clamp are inserted you have to specify the simulation" "duration at `jx.integrate(..., t_max=)`." ) for key in externals.keys(): externals[key] = externals[key].T # Shape `(time, num_stimuli)`. if module.recordings.empty: raise ValueError("No recordings are set. Please set them.") rec_inds = module.recordings.rec_index.to_numpy() rec_states = module.recordings.state.to_numpy() # Shorten or pad stimulus depending on `t_max`. if t_max is not None: t_max_steps = int(t_max // delta_t + 1) # Pad or truncate the stimulus. for key in externals.keys(): if t_max_steps > externals[key].shape[0]: if key == "i": pad = jnp.zeros( (t_max_steps - externals["i"].shape[0], externals["i"].shape[1]) ) externals["i"] = jnp.concatenate((externals["i"], pad)) else: raise NotImplementedError( "clamp must be at least as long as simulation." ) else: externals[key] = externals[key][:t_max_steps, :] init_fn, step_fn = build_init_and_step_fn( module, voltage_solver=voltage_solver, solver=solver ) all_states, all_params = init_fn(params, all_states, param_state, delta_t) def _body_fun(state, externals): state = step_fn(state, all_params, externals, external_inds, delta_t) recs = jnp.asarray( [ state[rec_state][rec_ind] for rec_state, rec_ind in zip(rec_states, rec_inds) ] ) return state, recs # If necessary, pad the stimulus with zeros in order to simulate sufficiently long. # The total simulation length will be `prod(checkpoint_lengths)`. At the end, we # return only the first `nsteps_to_return` elements (plus the initial state). if externals: example_key = list(externals.keys())[0] nsteps_to_return = len(externals[example_key]) else: nsteps_to_return = t_max_steps if checkpoint_lengths is None: checkpoint_lengths = [nsteps_to_return] length = nsteps_to_return else: length = prod(checkpoint_lengths) size_difference = length - nsteps_to_return assert ( nsteps_to_return <= length ), "The desired simulation duration is longer than `prod(nested_length)`." if externals: dummy_external = jnp.zeros( (size_difference, externals[example_key].shape[1]) ) for key in externals.keys(): externals[key] = jnp.concatenate([externals[key], dummy_external]) # Record the initial state. init_recs = jnp.asarray( [ all_states[rec_state][rec_ind] for rec_state, rec_ind in zip(rec_states, rec_inds) ] ) init_recording = jnp.expand_dims(init_recs, axis=0) # Run simulation. all_states, recordings = nested_checkpoint_scan( _body_fun, all_states, externals, length=length, nested_lengths=checkpoint_lengths, ) recs = jnp.concatenate([init_recording, recordings[:nsteps_to_return]], axis=0).T return (recs, all_states) if return_states else recs