jaxley.optimize.optimizer.TypeOptimizer

class TypeOptimizer(optimizer, optimizer_args, opt_params)

Bases: object

Wrapper around optax that allows different optimizer argument values (for example, different learning rates) for different parameters.
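The idea can be illustrated with a minimal, stdlib-only sketch. It is not jaxley's implementation: PerGroupOptimizer, sgd and Transformation are hypothetical names, plain floats stand in for JAX arrays, and Transformation only mimics the (init, update) interface of optax.GradientTransformation. The sketch assumes that each entry of opt_params holds a single parameter name, that optimizer_args maps that name to its argument value, and that optimizer is a callable turning that value into an optimizer.

```python
from typing import Any, Callable, Dict, List, NamedTuple, Tuple

Group = Dict[str, float]  # one parameter group, e.g. {"gNa": 0.12}


class Transformation(NamedTuple):
    """Toy stand-in for optax.GradientTransformation: an (init, update) pair."""

    init: Callable[[Group], Any]
    update: Callable[[Group, Any], Tuple[Group, Any]]


def sgd(lr: float) -> Transformation:
    """Hypothetical stateless SGD acting on a dict of floats."""

    def init(params: Group) -> None:
        return None  # plain SGD keeps no state

    def update(grads: Group, state: None) -> Tuple[Group, None]:
        # Updates are additive, as in optax: new_param = param + update.
        return {name: -lr * g for name, g in grads.items()}, state

    return Transformation(init, update)


class PerGroupOptimizer:
    """Hypothetical sketch of the TypeOptimizer idea: one optimizer per group."""

    def __init__(self, optimizer: Callable[[Any], Transformation],
                 optimizer_args: Dict[str, Any], opt_params: List[Group]):
        self.optimizers: List[Transformation] = []
        for group in opt_params:
            # Assumption: each group holds exactly one parameter name and
            # optimizer_args is keyed by that name.
            (name,) = group.keys()
            self.optimizers.append(optimizer(optimizer_args[name]))

    def init(self, opt_params: List[Group]) -> List[Any]:
        # One optimizer state per parameter group.
        return [opt.init(group) for opt, group in zip(self.optimizers, opt_params)]

    def update(self, gradient: List[Group],
               opt_state: List[Any]) -> Tuple[List[Group], List[Any]]:
        # Each group's gradient is handled by that group's own optimizer.
        updates, new_states = [], []
        for grad, state, opt in zip(gradient, opt_state, self.optimizers):
            upd, new_state = opt.update(grad, state)
            updates.append(upd)
            new_states.append(new_state)
        return updates, new_states


# Usage with two hypothetical parameter groups and different learning rates.
opt_params = [{"gNa": 0.12}, {"radius": 1.0}]
optimizer = PerGroupOptimizer(sgd, {"gNa": 0.01, "radius": 1.0}, opt_params)
opt_state = optimizer.init(opt_params)
grads = [{"gNa": 2.0}, {"radius": 2.0}]
updates, opt_state = optimizer.update(grads, opt_state)
print(updates)  # [{'gNa': -0.02}, {'radius': -2.0}]
```

With identical gradients, the radius group takes a step 100 times larger than the gNa group because its learning rate is 100 times larger. With real optax, the factory could be something like lambda lr: optax.adam(lr).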

Parameters:

optimizer

optimizer_args

opt_params (list[dict[str, Array]])

init(opt_params)

Initialize the optimizers and return their states as a list. Equivalent to the init() method of an optax optimizer (an optax.GradientTransformation).

Parameters:

opt_params (list[dict[str, Array]])

Return type:

list

update(gradient, opt_state)

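Update the optimizers and return the updates together with the new optimizer states. Equivalent to the update() method of an optax optimizer (an optax.GradientTransformation).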

Parameters:

gradient

opt_state

Return type:

tuple[list, list]
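In a training step, the second list returned by update() (the new optimizer states) is passed to the next update() call, and the first list (the updates) is added to the parameters. A minimal stdlib sketch of that last step, assuming hypothetical float-valued parameter groups; apply_updates here is a stand-in written for illustration, not part of jaxley:

```python
from typing import Dict, List

Group = Dict[str, float]  # one parameter group, e.g. {"gNa": 0.5}


def apply_updates(opt_params: List[Group], updates: List[Group]) -> List[Group]:
    """Add each group's update to its parameters (additive, as in optax)."""
    return [
        {name: value + upd[name] for name, value in group.items()}
        for group, upd in zip(opt_params, updates)
    ]


# Hypothetical values: `updates` plays the role of the first element
# of the tuple returned by update().
opt_params = [{"gNa": 0.5}, {"radius": 1.0}]
updates = [{"gNa": -0.25}, {"radius": -0.5}]
opt_params = apply_updates(opt_params, updates)
print(opt_params)  # [{'gNa': 0.25}, {'radius': 0.5}]
```

With real optax and JAX arrays, optax.apply_updates(opt_params, updates) should perform the same addition on the whole list at once, because a list of dicts is a valid pytree.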