# Source code for jaxley.optimize.optimizer
# This file is part of Jaxley, a differentiable neuroscience simulator. Jaxley is
# licensed under the Apache License Version 2.0, see <https://www.apache.org/licenses/>
from collections.abc import Callable
from typing import Any
from jax import Array
from jax.typing import ArrayLike
[docs]
class TypeOptimizer:
    """`optax` wrapper which allows different argument values for different params.

    One optimizer is built per entry of `opt_params`. Each entry must be a dict
    with exactly one key (the parameter name); that name selects which value of
    `optimizer_args` is handed to the optimizer factory.

    Attributes are kept public for backward compatibility:

    - `base_optimizer`: the factory passed as `optimizer`.
    - `optimizers`: a list of single-item dicts `{name: optimizer}`, in the same
      order as `opt_params`.
    """

    def __init__(
        self,
        optimizer: Callable,
        optimizer_args: dict[str, Any],
        opt_params: list[dict[str, Array]],
    ):
        """Create the optimizers.

        This requires access to `opt_params` in order to know how many optimizers
        should be created. It creates `len(opt_params)` optimizers.

        Example usage:
        ```
        lrs = {"HH_gNa": 0.01, "radius": 1.0}
        optimizer = TypeOptimizer(lambda lr: optax.adam(lr), lrs, opt_params)
        opt_state = optimizer.init(opt_params)
        ```

        ```
        optimizer_args = {"HH_gNa": [0.01, 0.4], "radius": [1.0, 0.8]}
        optimizer = TypeOptimizer(
            lambda args: optax.sgd(args[0], momentum=args[1]),
            optimizer_args,
            opt_params
        )
        opt_state = optimizer.init(opt_params)
        ```

        Args:
            optimizer: A Callable that takes one item of `optimizer_args` (e.g. a
                learning rate, or a list of arguments) and returns the optimizer
                which should be used for that parameter.
            optimizer_args: The arguments for different kinds of parameters, keyed
                by parameter name. Each looked-up value is passed to the
                `Callable` given as `optimizer`.
            opt_params: The parameters to be optimized. The exact values are not
                used, only the number of elements in the list and the key of
                each dict.

        Raises:
            AssertionError: If an element of `opt_params` does not contain exactly
                one key.
            KeyError: If a parameter name has no entry in `optimizer_args`.
        """
        self.base_optimizer = optimizer
        self.optimizers = []
        for params in opt_params:
            names = list(params.keys())
            # Raised explicitly (rather than via `assert`) so that the check is
            # not stripped when Python runs with `-O`. The exception type and
            # message are the same as before.
            if len(names) != 1:
                raise AssertionError("Multiple parameters were added at once.")
            name = names[0]
            # Keep the `try` limited to the lookup so that a KeyError raised
            # inside the user-supplied factory is not masked.
            try:
                args = optimizer_args[name]
            except KeyError:
                raise KeyError(
                    f"No optimizer arguments were provided for parameter '{name}'."
                ) from None
            self.optimizers.append({name: self.base_optimizer(args)})

    def _check_length(self, label: str, values) -> None:
        """Raise `ValueError` if `values` does not have one entry per optimizer."""
        if len(values) != len(self.optimizers):
            raise ValueError(
                f"`{label}` has {len(values)} elements, but "
                f"{len(self.optimizers)} optimizers were created."
            )

    def init(self, opt_params: list[dict[str, Array]]) -> list:
        """Initialize the optimizers. Equivalent to `optax.optimizers.init()`.

        Args:
            opt_params: The parameters to be optimized, in the same order (and
                with the same length) as the list passed to the constructor.

        Returns:
            A list with one optimizer state per element of `opt_params`.

        Raises:
            ValueError: If `len(opt_params)` differs from the number of
                optimizers (previously the extra entries were silently dropped).
        """
        self._check_length("opt_params", opt_params)
        opt_states = []
        for params, named_optimizer in zip(opt_params, self.optimizers):
            # Each entry of `self.optimizers` is a single-item dict.
            (optimizer,) = named_optimizer.values()
            opt_states.append(optimizer.init(params))
        return opt_states

    def update(
        self, gradient: list[dict[str, Array]], opt_state: list
    ) -> tuple[list, list]:
        """Update the optimizers. Equivalent to `optax.optimizers.update()`.

        Args:
            gradient: One gradient entry per optimized parameter, in the same
                order as the `opt_params` used to build the optimizers.
            opt_state: The list of optimizer states, as returned by `init()` or by
                a previous call to `update()`.

        Returns:
            A tuple `(all_updates, new_opt_states)`; both are lists with one
            element per optimizer.

        Raises:
            ValueError: If `gradient` or `opt_state` does not have one element
                per optimizer (previously the extra entries were silently
                dropped).
        """
        self._check_length("gradient", gradient)
        self._check_length("opt_state", opt_state)
        all_updates = []
        new_opt_states = []
        for grad, state, named_optimizer in zip(gradient, opt_state, self.optimizers):
            (optimizer,) = named_optimizer.values()
            updates, new_opt_state = optimizer.update(grad, state)
            all_updates.append(updates)
            new_opt_states.append(new_opt_state)
        return all_updates, new_opt_states