jaxley.utils.dynamics.build_dynamic_state_utils
- build_dynamic_state_utils(module)
Return functions which extract the dynamic (ODE) states of a jx.Module.

These utility functions are meant to be used together with jx.integrate.build_init_and_step_fn. The init_fn returned by build_init_and_step_fn returns all_states, a dictionary of all states, including observables: the voltages at branchpoints, the channel and synapse currents, and NaN elements for channel states which do not exist in certain compartments. The utility functions returned by build_dynamic_state_utils() modify all_states as follows:

- They remove all channel currents, synapse currents, and branchpoint voltages (which can be computed from compartment voltages). Additionally, if states are only defined on a subset of compartments, the NaN padding is removed. As such, only “true” dynamic states remain. This is handled by the returned functions remove_observables and add_observables.
- They return the states as a flat array. This allows easier interoperability with frameworks such as dynamax. This is handled by the returned functions flatten and unflatten.
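The two transformations described above can be illustrated on a toy dictionary without building a model. This is a minimal sketch in plain JAX, not Jaxley's implementation: the state names (`v`, `HH_m`, `i_HH`), the set of observable keys, and the helper `toy_remove_observables` are all hypothetical, and `jax.flatten_util.ravel_pytree` stands in for the returned flatten and unflatten functions.

```python
import numpy as np
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# Toy stand-in for all_states (hypothetical keys and values).
all_states = {
    "v": jnp.array([-70.0, -69.5, -68.0]),    # voltage in three compartments
    "HH_m": jnp.array([0.1, jnp.nan, 0.3]),   # gate missing in compartment 1 -> NaN padding
    "i_HH": jnp.array([0.0, 0.0, 0.0]),       # channel current, i.e. an observable
}
observable_keys = {"i_HH"}  # assumption: which entries count as observables

# Boolean masks marking the entries that are not NaN padding.
masks = {
    key: ~np.isnan(np.asarray(value))
    for key, value in all_states.items()
    if key not in observable_keys
}


def toy_remove_observables(states):
    """Drop observables and NaN padding, keeping only the toy dynamic states."""
    return {key: states[key][mask] for key, mask in masks.items()}


dynamic_states_pytree = toy_remove_observables(all_states)

# Flatten the remaining dictionary into one vector and keep the inverse mapping.
flat_states, toy_unflatten = ravel_pytree(dynamic_states_pytree)
restored = toy_unflatten(flat_states)

print(flat_states.shape)  # (5,): three voltages plus two gate values
```

The flat vector is the form expected by libraries that operate on a single state array, while the inverse mapping returns a dictionary with the same keys and shapes as before flattening.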
Warning
If a membrane current is used as a state by another channel (for example, the calcium current i_ca used in the channel_states of a calcium-dependent potassium channel), then this current will be included in the returned “true” ODE states. Similarly, if there are channel_states that are directly computed based on other states (e.g., the calcium reversal potential via the Nernst equation), they are also considered “true” ODE states. Please remove such states manually if desired.

- Parameters:
module – A Module object, e.g., a jx.Cell.

- Returns:
remove_observables(all_states)
Callable which removes the membrane currents, synaptic currents, branchpoint voltages, and NaN padding from the states dict. The returned states only include true “dynamic” states.
Args:
    all_states (Dict[str, Array]): All states of the system which can be recorded.
Returns:
    Dynamic states of the system (Dict[str, Array]).

add_observables(dynamic_states_pytree, all_params, delta_t)
Callable which adds membrane currents, synaptic currents, and branchpoint voltages to the states dictionary.
Args:
    dynamic_states_pytree (Dict[str, Array])
    all_params (Dict[str, Array])
    delta_t (float)
Returns:
    All states of the system which can be recorded (Dict[str, Array]).

flatten(dynamic_states_pytree)
Callable which flattens dynamic states as a pytree into a jnp.Array.
Args:
    dynamic_states_pytree (Dict[str, Array]): All dynamic states.
Returns:
    Dynamic states of the system as a flattened Array (Array).

unflatten(*args)
Callable which converts the state vector back to a pytree.
Args:
    The dynamic states of the system as a flat JAX array (Array).
Returns:
    Dynamic states as a dict of Arrays (Dict[str, Array]).
- Return type:
Example usage
Example 1: Use the functions returned by build_dynamic_state_utils to build a vector of dynamic states. Then convert the vector back to the all_states dictionary.

    import jaxley as jx
    from jaxley.integrate import build_init_and_step_fn
    from jaxley.utils.dynamics import build_dynamic_state_utils

    cell = jx.Cell()
    params = cell.get_parameters()

    init_fn, step_fn = build_init_and_step_fn(cell)
    remove_observables, add_observables, flatten, unflatten = build_dynamic_state_utils(cell)

    # all_states holds the ODE states plus observables and NaN padding.
    all_states, all_params = init_fn(params)

    # Strip observables and padding, then flatten into a single vector.
    dynamic_states = flatten(remove_observables(all_states))

    # Invert both operations to recover the full states dictionary.
    recovered_all_states = add_observables(unflatten(dynamic_states), all_params, delta_t=0.025)
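The warning above leaves it to the user to remove states that are derived from other states. A minimal sketch of one way to do this on a plain dictionary follows. It is an illustration only: the keys `i_Ca`, `eCa`, and `CaCon_i`, the toy values, and the helpers `drop_derived` and `restore_derived` are hypothetical. It also assumes that the derived entries have to be put back before the dictionary is handed to any function that expects the original set of keys.

```python
import jax.numpy as jnp

# Toy stand-in for the dictionary returned by remove_observables
# (hypothetical keys and values).
dynamic_states_pytree = {
    "v": jnp.array([-70.0, -70.0]),
    "CaCon_i": jnp.array([5e-5, 5e-5]),
    "i_Ca": jnp.array([0.0, 0.0]),     # current that another channel uses as a state
    "eCa": jnp.array([120.0, 120.0]),  # reversal potential computed from other states
}
derived_keys = ("i_Ca", "eCa")  # assumption: entries the user considers derived


def drop_derived(states, keys):
    """Split a states dict into (states without `keys`, the removed entries)."""
    reduced = {k: v for k, v in states.items() if k not in keys}
    removed = {k: v for k, v in states.items() if k in keys}
    return reduced, removed


def restore_derived(reduced, removed):
    """Merge the removed entries back so the dict has its original keys again."""
    return {**reduced, **removed}


reduced_states, removed_states = drop_derived(dynamic_states_pytree, derived_keys)
full_states = restore_derived(reduced_states, removed_states)

print(sorted(reduced_states))  # ['CaCon_i', 'v']
```

Keeping the removed entries around, rather than discarding them, makes it possible to rebuild a dictionary with the full key set when later steps need it.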
Example 2: Build a step_dynamics function and use it to compute the Jacobian of a single step.

    from jax import jacfwd

    import jaxley as jx
    from jaxley.integrate import build_init_and_step_fn
    from jaxley.utils.dynamics import build_dynamic_state_utils
    from jaxley.channels import Leak

    # Build a small cell with three branches of two compartments each.
    comp = jx.Compartment()
    branch = jx.Branch(comp, 2)
    cell = jx.Cell(branch, parents=[-1, 0, 0])
    cell.insert(Leak())

    params = cell.get_parameters()
    externals = cell.externals.copy()
    external_inds = cell.external_inds.copy()

    init_fn, step_fn = build_init_and_step_fn(cell)
    remove_observables, add_observables, flatten, unflatten = build_dynamic_state_utils(cell)

    all_states, all_params = init_fn(params)
    dynamic_states = flatten(remove_observables(all_states))

    def step_dynamics(dynamic_states, all_params, externals, external_inds, delta_t):
        # Rebuild the full states dictionary, take one step, and reduce again.
        all_states = add_observables(unflatten(dynamic_states), all_params, delta_t)
        all_states = step_fn(
            all_states, all_params, externals, external_inds, delta_t=delta_t
        )
        dynamic_states = flatten(remove_observables(all_states))
        return dynamic_states

    # jacfwd differentiates with respect to the first argument, dynamic_states.
    jacobian = jacfwd(step_dynamics)(dynamic_states, all_params, externals, external_inds, delta_t=0.025)
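To make the shape and meaning of such a Jacobian concrete, the same `jacfwd` pattern can be applied to a toy step function whose Jacobian is known in closed form. This is a sketch under stated assumptions and does not use Jaxley: `toy_step_dynamics`, its parameters, and the forward-Euler update are hypothetical stand-ins and say nothing about the solver Jaxley uses internally.

```python
import jax.numpy as jnp
from jax import jacfwd


def toy_step_dynamics(dynamic_states, g_leak, e_leak, delta_t):
    """One forward-Euler step of dv/dt = -g_leak * (v - e_leak), applied per entry."""
    dv = -g_leak * (dynamic_states - e_leak)
    return dynamic_states + delta_t * dv


dynamic_states = jnp.array([-70.0, -65.0, -60.0])
g_leak, e_leak, delta_t = 0.1, -70.0, 0.025

# As in Example 2, jacfwd differentiates with respect to the first argument.
jacobian = jacfwd(toy_step_dynamics)(dynamic_states, g_leak, e_leak, delta_t=delta_t)

# Each entry evolves independently here, so the Jacobian is diagonal with
# value 1 - delta_t * g_leak on the diagonal.
expected = (1.0 - delta_t * g_leak) * jnp.eye(3)

print(jacobian.shape)  # (3, 3)
```

For a flat state vector of length n, the result is an n by n matrix whose entry (i, j) is the sensitivity of state i after the step to state j before the step. In a model with coupled compartments, the off-diagonal entries would generally be nonzero.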
Example 3: Build a loss function based on input and parameters.

    import jax.numpy as jnp

    import jaxley as jx
    from jaxley.integrate import build_init_and_step_fn
    from jaxley.utils.dynamics import build_dynamic_state_utils
    from jaxley.channels import Leak

    cell = jx.Cell()
    cell.insert(Leak())

    t_max = 3.0
    delta_t = 0.025

    cell.record()
    cell.stimulate(jx.step_current(0, 1, 2, delta_t, t_max))

    rec_inds = cell.recordings.rec_index.to_numpy()
    rec_states = cell.recordings.state.to_numpy()
    externals = cell.externals.copy()
    external_inds = cell.external_inds.copy()

    cell.make_trainable("radius")
    params = cell.get_parameters()

    init_fn, step_fn = build_init_and_step_fn(cell)
    remove_observables, add_observables, flatten, unflatten = build_dynamic_state_utils(cell)

    def init_dynamics(params, param_state):
        all_states, all_params = init_fn(params, None, param_state)
        # Store the recorded states at time zero.
        recordings = [jnp.asarray(
            [
                all_states[rec_state][rec_ind]
                for rec_state, rec_ind in zip(rec_states, rec_inds)
            ]
        )]
        dynamic_states = flatten(remove_observables(all_states))
        return dynamic_states, all_params, recordings

    def step_dynamics(dynamic_states, all_params, externals, external_inds):
        # Use the same delta_t for rebuilding the observables and for stepping.
        all_states = add_observables(unflatten(dynamic_states), all_params, delta_t)
        all_states = step_fn(
            all_states, all_params, externals, external_inds, delta_t=delta_t
        )
        recs = jnp.asarray(
            [
                all_states[rec_state][rec_ind]
                for rec_state, rec_ind in zip(rec_states, rec_inds)
            ]
        )
        dynamic_states = flatten(remove_observables(all_states))
        return dynamic_states, recs

    def loss_fn(params, param_state_value):
        param_state = cell.data_set("Leak_gLeak", param_state_value, None)
        cell.to_jax()
        dynamic_states, all_params, recordings = init_dynamics(params, param_state)

        steps = int(t_max / delta_t)
        for step in range(steps):
            # Select the external inputs (e.g., the stimulus) for this time step.
            externals_now = {}
            for key in externals.keys():
                externals_now[key] = externals[key][:, step]
            dynamic_states, recs = step_dynamics(dynamic_states, all_params, externals_now, external_inds)
            recordings.append(recs)
        return jnp.mean(jnp.stack(recordings, axis=0).T)

    loss = loss_fn(params, 1e-4)
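A loss built from a loop of step calls, as in Example 3, is an ordinary JAX function and can therefore be differentiated. The sketch below shows the pattern on a toy model rather than on a Jaxley cell: `toy_step`, `toy_loss`, the leak dynamics, and all numerical values are hypothetical, and whether `loss_fn` from Example 3 can be passed to `jax.grad` unchanged is not verified here.

```python
import jax
import jax.numpy as jnp


def toy_step(v, g_leak, stim, delta_t):
    """One forward-Euler step of the hypothetical dynamics dv/dt = -g_leak * v + stim."""
    return v + delta_t * (-g_leak * v + stim)


def toy_loss(g_leak, stimulus, delta_t=0.025):
    """Mean of the recorded state over a rollout driven by `stimulus`."""
    v = jnp.zeros(())
    recordings = [v]
    for stim in stimulus:  # same loop-and-append pattern as in Example 3
        v = toy_step(v, g_leak, stim, delta_t)
        recordings.append(v)
    return jnp.mean(jnp.stack(recordings))


stimulus = jnp.ones(20)

# Differentiate the loss with respect to its first argument, the parameter.
loss, grad = jax.value_and_grad(toy_loss)(0.1, stimulus)
```

In this toy model a larger leak conductance pulls the state toward zero, so the gradient of the mean recorded state with respect to `g_leak` is negative. For long rollouts, a Python loop is unrolled during tracing, which can make compilation slow.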