jaxley.solver_gate.heaviside
- heaviside(x, at_zero=1.0, grad_scale=10.0)
Compute the heaviside step function with a custom derivative.
Jaxley implementation of jax.numpy.heaviside, which includes a custom derivative.
The custom derivative is \(\frac{1}{(g|x| + 1)^2}\), where g is grad_scale. If you experience exploding or vanishing derivatives when using this function, try changing the value of grad_scale.
Note that while this function works when x and at_zero are JAX arrays, you can only take the gradient of this function when both are scalar values.
- Parameters:
x (Array | ndarray | bool | number | bool | int | float | complex) – Input array or scalar. complex dtypes are not supported.
at_zero (Array | ndarray | bool | number | bool | int | float | complex) – Scalar or array. Specifies the return value when x is 0. complex dtypes are not supported. x and at_zero must either have the same shape or be broadcast compatible.
grad_scale (float) – Specifies the flatness of the gradient curve. Larger values bring the custom derivative closer to the ‘real’ gradient, but make the function more susceptible to exploding/vanishing gradients.
- Returns:
An array containing the heaviside step function of x, promoting to inexact dtype.
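
The following is a minimal sketch, not Jaxley's actual source, of how a derivative of the form \(\frac{1}{(g|x| + 1)^2}\) could be attached to jax.numpy.heaviside with jax.custom_jvp. The names heaviside_sketch and _step are hypothetical and exist only for this illustration. The sketch assumes scalar x and at_zero when differentiating, in line with the note above.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def _step(x, at_zero, grad_scale):
    # Forward pass: the ordinary heaviside step function.
    return jnp.heaviside(x, at_zero)


@_step.defjvp
def _step_jvp(primals, tangents):
    x, at_zero, grad_scale = primals
    x_dot, _, _ = tangents  # only the tangent of x is propagated
    out = jnp.heaviside(x, at_zero)
    # Surrogate derivative 1 / (g * |x| + 1)^2 with g = grad_scale.
    surrogate = 1.0 / (grad_scale * jnp.abs(x) + 1.0) ** 2
    return out, surrogate * x_dot


def heaviside_sketch(x, at_zero=1.0, grad_scale=10.0):
    # Hypothetical wrapper mirroring the documented signature.
    return _step(x, at_zero, grad_scale)


# Gradient with respect to x at a scalar input, for two values of grad_scale.
grad_default = jax.grad(heaviside_sketch)(0.5)
grad_flat = jax.grad(lambda x: heaviside_sketch(x, grad_scale=1.0))(0.5)
print(grad_default, grad_flat)
```

At x = 0.5 the formula gives 1/(10 * 0.5 + 1)^2 = 1/36 for the default grad_scale and 1/(1 * 0.5 + 1)^2 = 4/9 for grad_scale=1.0, which illustrates how a smaller grad_scale flattens the gradient curve and keeps the derivative larger away from zero.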