jaxley.utils.dynamics.build_dynamic_state_utils


build_dynamic_state_utils(module)

Return functions which extract the dynamic (ODE) states of a jx.Module.

These utility functions are meant to be used together with jaxley.integrate.build_init_and_step_fn. The init_fn returned by build_init_and_step_fn returns all_states, a dictionary of all states, including observables: the voltages at branchpoints, the channel and synapse currents, and NaN entries for channel states that do not exist in certain compartments. The utility functions returned by build_dynamic_state_utils() modify all_states as follows:

  • remove_observables drops all channel currents, synapse currents, and branchpoint voltages (the latter can be computed from the compartment voltages). Additionally, if a state is only defined on a subset of compartments, its NaN padding is removed. As such, only “true” dynamic states remain. add_observables adds the observables back.

  • flatten returns the dynamic states as a single flat array, which eases interoperability with frameworks such as dynamax. unflatten converts such an array back into a dictionary.
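A minimal NumPy sketch of the two transformations above, applied to a hypothetical toy dictionary. The key names, observable_keys, and state_inds are made up for illustration, and this is not jaxley's implementation; in JAX, jax.flatten_util.ravel_pytree provides a comparable flatten/unflatten pair for arbitrary pytrees.

```python
import numpy as np

# Hypothetical toy "all_states": real jaxley keys depend on the inserted
# channels and synapses and are not reproduced here.
all_states = {
    "v": np.array([-70.0, -65.0, -60.0]),     # compartment voltages (dynamic)
    "i_Leak": np.array([0.1, 0.2, 0.3]),     # channel current (observable)
    "HH_m": np.array([0.05, np.nan, 0.07]),  # gate defined only in compartments 0 and 2
}
# Assumed bookkeeping: which keys are observables, and where each state exists.
observable_keys = {"i_Leak"}
state_inds = {"v": np.array([0, 1, 2]), "HH_m": np.array([0, 2])}


def remove_observables_sketch(states):
    """Drop observables and keep only the compartments where each state exists."""
    return {
        key: value[state_inds[key]]
        for key, value in states.items()
        if key not in observable_keys
    }


def make_flatten_utils(template):
    """Build flatten/unflatten closures for dicts shaped like `template`."""
    keys = sorted(template)  # fixed key order so the layout is reproducible
    shapes = [np.shape(template[k]) for k in keys]
    sizes = [int(np.prod(s)) for s in shapes]
    offsets = np.cumsum(sizes)[:-1]  # split points between consecutive states

    def flatten(tree):
        return np.concatenate([np.ravel(tree[k]) for k in keys])

    def unflatten(flat):
        chunks = np.split(flat, offsets)
        return {k: c.reshape(s) for k, c, s in zip(keys, chunks, shapes)}

    return flatten, unflatten


dynamic_states = remove_observables_sketch(all_states)
flatten, unflatten = make_flatten_utils(dynamic_states)
flat = flatten(dynamic_states)  # HH_m entries first, then v (sorted key order)
recovered = unflatten(flat)
```

Fixing the key order and the split offsets once, when the closures are built, is what lets unflatten invert flatten without extra metadata being passed around.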

Warning

If a membrane current is used as a state by another channel (for example, the calcium current i_ca used in the channel_states of a calcium-dependent potassium channel), this current is included in the returned “true” ODE states. Similarly, channel_states that are computed directly from other states (e.g., the calcium reversal potential computed via the Nernst equation) are also treated as “true” ODE states. Please remove such states manually if desired.
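The flatten/unflatten pair is presumably built for the full set of dynamic states, so deleting keys right before flatten is likely to break unflatten. One option, sketched below under that assumption, is to set derived states aside and merge them back before calling flatten. Both helpers are hypothetical, and apart from i_ca (mentioned above) the key names are made up.

```python
import numpy as np


def split_derived(dynamic_states, derived_keys):
    """Split states into true ODE states and derived ones (e.g. i_ca)."""
    core = {k: v for k, v in dynamic_states.items() if k not in derived_keys}
    derived = {k: v for k, v in dynamic_states.items() if k in derived_keys}
    return core, derived


def merge_derived(core, derived):
    """Recombine both parts so the dict again has every key flatten expects."""
    return {**core, **derived}


# Hypothetical keys: "eCa" stands in for a reversal potential computed via Nernst.
dynamic_states = {
    "v": np.array([-70.0, -65.0]),
    "i_ca": np.array([0.01, 0.02]),
    "eCa": np.array([120.0, 121.0]),
}
core, derived = split_derived(dynamic_states, {"i_ca", "eCa"})
merged = merge_derived(core, derived)
```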

Parameters:

module – A Module object, e.g., a jx.Cell.

Returns:

  • remove_observables(all_states)

    Callable which removes the membrane currents, synaptic currents, branchpoint voltages and NaN padding from the states dict. The returned states only include true “dynamic” states.

    • Args:

      • all_states (Dict[str, Array]): All states of the system which can be recorded.

    • Returns:

      • Dynamic states of the system (Dict[str, Array]).

  • add_observables(dynamic_states_pytree, all_params, delta_t)

    Callable which adds membrane currents, synaptic currents, and branchpoint voltages to the states dictionary.

    • Args:

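      • dynamic_states_pytree (Dict[str, Array]): Dynamic states, e.g., as returned by remove_observables or unflatten.

      • all_params (Dict[str, Array]): Parameters of the system, e.g., as returned by init_fn.

      • delta_t (float): Time step of the simulation.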

    • Returns:

      • All states of the system which can be recorded (Dict[str, Array]).

  • flatten(dynamic_states_pytree)

    Callable which flattens a pytree of dynamic states into a single flat jax.Array.

    • Args:

      • dynamic_states_pytree (Dict[str, Array]): All dynamic states.

    • Returns:

      • Dynamic states of the system as a flattened Array (Array).

  • unflatten(*args)

    Callable which converts the state vector back to a pytree.

    • Args:

      • The dynamic states of the system as a flat jax array (Array).

    • Returns:

      • Dynamic states as a dict of Arrays (Dict[str, Array]).

Return type:

Tuple[Callable, Callable, Callable, Callable]
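Because flatten and unflatten are meant to be inverses, a quick round-trip check can catch layout mistakes, for example after manually editing the states dict. The helper below is a hypothetical sketch, not part of jaxley; with the returned functions it would be called as check_flatten_roundtrip(flatten, unflatten, remove_observables(all_states)). Here it is exercised with toy stand-ins.

```python
import numpy as np


def check_flatten_roundtrip(flatten, unflatten, dynamic_states):
    """Return True if unflatten(flatten(x)) reproduces every array in x."""
    recovered = unflatten(flatten(dynamic_states))
    if set(recovered) != set(dynamic_states):
        return False
    # Compare shapes first so that allclose never silently broadcasts.
    return all(
        np.shape(recovered[k]) == np.shape(dynamic_states[k])
        and np.allclose(recovered[k], dynamic_states[k])
        for k in dynamic_states
    )


# Toy stand-ins for the returned flatten/unflatten (hypothetical layout).
def toy_flatten(d):
    return np.concatenate([d["a"], d["b"]])


def toy_unflatten(x):
    return {"a": x[:2], "b": x[2:]}


states = {"a": np.array([1.0, 2.0]), "b": np.array([3.0])}
ok = check_flatten_roundtrip(toy_flatten, toy_unflatten, states)
```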

Example usage

Example 1: Use the functions returned by build_dynamic_state_utils to build a flat vector of dynamic states, then convert the vector back to the all_states dictionary.

import jaxley as jx
from jaxley.integrate import build_init_and_step_fn
from jaxley.utils.dynamics import build_dynamic_state_utils

cell = jx.Cell()
params = cell.get_parameters()

init_fn, step_fn = build_init_and_step_fn(cell)
remove_observables, add_observables, flatten, unflatten = build_dynamic_state_utils(cell)

all_states, all_params = init_fn(params)

dynamic_states = flatten(remove_observables(all_states))
recovered_all_states = add_observables(unflatten(dynamic_states), all_params, delta_t=0.025)
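This page does not state whether recovered_all_states matches all_states exactly (for example, whether recomputed currents agree to floating-point precision). A small NumPy helper, hypothetical and not part of jaxley, can list the keys that differ while treating identical NaN padding as equal; with the objects above it would be called as compare_state_dicts(all_states, recovered_all_states).

```python
import numpy as np


def compare_state_dicts(a, b, rtol=1e-5, atol=1e-8):
    """Return the sorted keys whose arrays differ between two state dicts.

    NaN entries at the same positions (e.g. padding for channels that are
    missing in some compartments) are treated as equal.
    """
    mismatched = []
    for key in sorted(set(a) | set(b)):
        if key not in a or key not in b:
            mismatched.append(key)  # key present in only one dict
            continue
        x, y = np.asarray(a[key]), np.asarray(b[key])
        if x.shape != y.shape or not np.allclose(
            x, y, rtol=rtol, atol=atol, equal_nan=True
        ):
            mismatched.append(key)
    return mismatched


# Toy dicts: identical values, including NaN padding in "HH_m".
ref = {"v": np.array([-70.0, -65.0]), "HH_m": np.array([np.nan, 0.05])}
same = {"v": np.array([-70.0, -65.0]), "HH_m": np.array([np.nan, 0.05])}
mismatches = compare_state_dicts(ref, same)  # no differing keys
```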

Example 2: Build a step_dynamics function and use it to compute the Jacobian of a single step.

from jax import jacfwd

import jaxley as jx
from jaxley.integrate import build_init_and_step_fn
from jaxley.utils.dynamics import build_dynamic_state_utils
from jaxley.channels import Leak

comp = jx.Compartment()
branch = jx.Branch(comp, 2)
cell = jx.Cell(branch, parents=[-1, 0, 0])
cell.insert(Leak())
params = cell.get_parameters()

externals = cell.externals.copy()
external_inds = cell.external_inds.copy()

init_fn, step_fn = build_init_and_step_fn(cell)
remove_observables, add_observables, flatten, unflatten = build_dynamic_state_utils(cell)

all_states, all_params = init_fn(params)
dynamic_states = flatten(remove_observables(all_states))

def step_dynamics(dynamic_states, all_params, externals, external_inds, delta_t):
    all_states = add_observables(unflatten(dynamic_states), all_params, delta_t)
    all_states = step_fn(
        all_states, all_params, externals, external_inds, delta_t=delta_t
    )
    dynamic_states = flatten(remove_observables(all_states))
    return dynamic_states

jacobian = jacfwd(step_dynamics)(dynamic_states, all_params, externals, external_inds, delta_t=0.025)
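step_dynamics maps a flat state vector to a flat state vector of the same length, so the Jacobian above is square. A standard use of such a one-step Jacobian is a local stability check: if dynamic_states is a fixed point of the step map (for instance a resting state without input) and the spectral radius of the Jacobian is below 1, that fixed point is locally asymptotically stable under this discrete step. The helper is a hypothetical NumPy sketch; np.asarray(jacobian) converts the JAX array.

```python
import numpy as np


def spectral_radius(jacobian):
    """Largest eigenvalue magnitude of a square one-step Jacobian."""
    jac = np.asarray(jacobian)
    if jac.ndim != 2 or jac.shape[0] != jac.shape[1]:
        raise ValueError(f"expected a square matrix, got shape {jac.shape}")
    return float(np.max(np.abs(np.linalg.eigvals(jac))))


# Toy Jacobians: a contracting map and a pure rotation (eigenvalues +i and -i).
contracting = np.array([[0.9, 0.0], [0.0, 0.5]])
rotation = np.array([[0.0, -1.0], [1.0, 0.0]])
rho_contracting = spectral_radius(contracting)
rho_rotation = spectral_radius(rotation)
```

A spectral radius exactly at 1, as for the rotation, is inconclusive for the nonlinear map; the linearization alone does not decide stability in that case.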

Example 3: Build a loss function that depends on both the external input (a step current) and the parameters, simulating the cell step by step.

import jax.numpy as jnp

import jaxley as jx
from jaxley.integrate import build_init_and_step_fn
from jaxley.utils.dynamics import build_dynamic_state_utils
from jaxley.channels import Leak

cell = jx.Cell()
cell.insert(Leak())
t_max = 3.0
delta_t = 0.025

cell.record()
cell.stimulate(jx.step_current(0, 1, 2, delta_t, t_max))

rec_inds = cell.recordings.rec_index.to_numpy()
rec_states = cell.recordings.state.to_numpy()
externals = cell.externals.copy()
external_inds = cell.external_inds.copy()

cell.make_trainable("radius")
params = cell.get_parameters()

init_fn, step_fn = build_init_and_step_fn(cell)
remove_observables, add_observables, flatten, unflatten = build_dynamic_state_utils(cell)

def init_dynamics(params, param_state):
    all_states, all_params = init_fn(params, None, param_state)
    recordings = [jnp.asarray(
        [
            all_states[rec_state][rec_ind]
            for rec_state, rec_ind in zip(rec_states, rec_inds)
        ]
    )]
    dynamic_states = flatten(remove_observables(all_states))
    return dynamic_states, all_params, recordings

def step_dynamics(dynamic_states, all_params, externals, external_inds):
    all_states = add_observables(unflatten(dynamic_states), all_params, delta_t)
    all_states = step_fn(
        all_states, all_params, externals, external_inds, delta_t=delta_t
    )
    recs = jnp.asarray(
        [
            all_states[rec_state][rec_ind]
            for rec_state, rec_ind in zip(rec_states, rec_inds)
        ]
    )
    dynamic_states = flatten(remove_observables(all_states))
    return dynamic_states, recs

def loss_fn(params, param_state_value):
    param_state = cell.data_set("Leak_gLeak", param_state_value, None)
    cell.to_jax()
    dynamic_states, all_params, recordings = init_dynamics(params, param_state)
    steps = int(t_max / delta_t)
    for step in range(steps):
        externals_now = {}
        for key in externals.keys():
            # Select the external inputs (e.g., the step current) for this time step.
            externals_now[key] = externals[key][:, step]
        dynamic_states, recs = step_dynamics(dynamic_states, all_params, externals_now, external_inds)
        recordings.append(recs)
    return jnp.mean(jnp.stack(recordings, axis=0).T)

loss = loss_fn(params, 1e-4)
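To sanity-check a gradient of loss_fn with respect to param_state_value, for instance one computed with jax.grad(loss_fn, argnums=1) (not shown on this page, and assumed here to trace through cell.data_set), a central finite difference is a common reference. The helper below is a generic, hypothetical sketch, demonstrated on a toy function. Note that JAX computes in float32 unless jax_enable_x64 is enabled, so finite-difference estimates of a simulation loss can be noisy in single precision.

```python
def central_difference(f, x, rel_step=1e-4):
    """Central finite-difference estimate of df/dx for scalar f and scalar x.

    The step is scaled with |x| so that small parameter values such as
    gLeak = 1e-4 are not swamped by an absolute step.
    """
    h = rel_step * abs(x) if x != 0 else rel_step
    return (f(x + h) - f(x - h)) / (2.0 * h)


# Toy check: d/dx (3x^2 + 2x) = 6x + 2, which is 11 at x = 1.5.
estimate = central_difference(lambda x: 3 * x**2 + 2 * x, 1.5)
```

Applied to the example above, the call would look like central_difference(lambda g: float(loss_fn(params, g)), 1e-4).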