jaxley.solver_gate.heaviside

heaviside(x, at_zero=1.0, grad_scale=10.0)

Compute the Heaviside step function with a custom derivative.

Jaxley implementation of jax.numpy.heaviside, which includes a custom derivative.

The custom derivative is \(\frac{1}{(g|x| + 1)^2}\), where \(g\) is grad_scale. It equals 1 at x = 0 and decays toward 0 as |x| grows, and it decays faster for larger g. If you observe exploding or vanishing derivatives when using this function, try adjusting grad_scale.
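To make the custom derivative concrete, here is a minimal sketch of the technique using jax.custom_jvp. It is an illustration under stated assumptions, not Jaxley's source: the name heaviside_sketch is hypothetical, the forward pass simply calls jax.numpy.heaviside, and the backward pass substitutes the surrogate 1 / (g|x| + 1)^2 while giving at_zero and grad_scale no gradient.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def heaviside_sketch(x, at_zero=1.0, grad_scale=10.0):
    """Hypothetical stand-in for heaviside: the forward pass is jax.numpy.heaviside."""
    return jnp.heaviside(x, at_zero)


@heaviside_sketch.defjvp
def _heaviside_sketch_jvp(primals, tangents):
    x, at_zero, grad_scale = primals
    x_dot = tangents[0]  # assumption: only x receives a gradient in this sketch
    primal_out = jnp.heaviside(x, at_zero)
    # Surrogate derivative 1 / (g|x| + 1)^2 replaces the true derivative (zero for x != 0).
    surrogate = 1.0 / (grad_scale * jnp.abs(x) + 1.0) ** 2
    return primal_out, surrogate * x_dot


# Forward pass: 0 for x < 0, at_zero (here 0.5) for x == 0, 1 for x > 0.
values = heaviside_sketch(jnp.array([-1.0, 0.0, 2.0]), 0.5)

# Backward pass at x = 0.5 with g = 10: 1 / (10 * 0.5 + 1)^2 = 1 / 36.
slope = jax.grad(heaviside_sketch)(0.5)
```

Because the surrogate equals 1 at x = 0 and shrinks as |x| grows, gradients still flow for inputs near the threshold, even though the exact derivative of a step is zero almost everywhere.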

Note that while this function accepts JAX arrays for x and at_zero, its gradient can only be taken when both are scalars.
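jax.grad itself only differentiates scalar-valued functions, so one common JAX pattern for obtaining one derivative per array element is to vmap a scalar gradient; each call then sees a scalar x and a scalar at_zero. The sketch assumes a hypothetical stand-in, heaviside_sketch, built with jax.custom_jvp and the surrogate derivative 1 / (g|x| + 1)^2; it is not Jaxley's implementation.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def heaviside_sketch(x, at_zero=1.0, grad_scale=10.0):
    """Hypothetical stand-in: jax.numpy.heaviside forward pass."""
    return jnp.heaviside(x, at_zero)


@heaviside_sketch.defjvp
def _heaviside_sketch_jvp(primals, tangents):
    x, at_zero, grad_scale = primals
    # Surrogate derivative 1 / (g|x| + 1)^2; at_zero and grad_scale get no gradient.
    surrogate = 1.0 / (grad_scale * jnp.abs(x) + 1.0) ** 2
    return jnp.heaviside(x, at_zero), surrogate * tangents[0]


# jax.grad differentiates the function for a single scalar x ...
scalar_grad = jax.grad(heaviside_sketch)
# ... and jax.vmap applies that scalar gradient to every element of an array.
elementwise_grad = jax.vmap(scalar_grad)

xs = jnp.array([-1.0, -0.1, 0.0, 0.1, 1.0])
# Elementwise 1 / (10 * |x| + 1)^2, i.e. 1/121, 1/4, 1, 1/4, 1/121.
grads = elementwise_grad(xs)
```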

Parameters:
  • x (Array | ndarray | bool_ | number | bool | int | float | complex) – Input array or scalar. Complex dtypes are not supported.

  • at_zero (Array | ndarray | bool_ | number | bool | int | float | complex) – Scalar or array specifying the return value where x is 0. Complex dtypes are not supported. x and at_zero must either have the same shape or be broadcast-compatible.

  • grad_scale (float) – Controls the flatness of the gradient curve: smaller values give a flatter, wider curve. Larger values bring the curve closer to the true gradient (zero for all x ≠ 0), but make the function more susceptible to exploding or vanishing gradients.
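The effect of grad_scale can be checked by evaluating the surrogate derivative 1 / (g|x| + 1)^2 directly. This plain-Python sketch (the helper surrogate_derivative is hypothetical, written only from the formula in the description) prints the derivative at a few distances from the threshold for three values of g.

```python
def surrogate_derivative(x: float, grad_scale: float) -> float:
    """Evaluate the surrogate derivative 1 / (g|x| + 1)^2 with g = grad_scale."""
    return 1.0 / (grad_scale * abs(x) + 1.0) ** 2


# Compare how quickly the derivative decays away from x = 0 for several grad_scale values.
for g in (1.0, 10.0, 100.0):
    row = [surrogate_derivative(x, g) for x in (0.0, 0.1, 1.0)]
    print(f"grad_scale={g:g}: " + ", ".join(f"{v:.2e}" for v in row))
```

At x = 0 every row starts at 1; at x = 1 the derivative falls from 0.25 (g = 1) to about 1e-4 (g = 100), so inputs far from the threshold contribute almost no gradient when grad_scale is large.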

Returns:

An array containing the Heaviside step function of x, promoted to an inexact (floating-point) dtype.
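A quick way to see the dtype promotion, assuming the forward pass matches jax.numpy.heaviside as the description states: integer inputs come back as a floating-point array.

```python
import jax.numpy as jnp

# Integer x and integer at_zero; jnp.heaviside promotes both to an inexact dtype.
x = jnp.array([-2, 0, 3])
out = jnp.heaviside(x, 1)

# Values: 0 for x < 0, at_zero (here 1) for x == 0, 1 for x > 0.
print(out.tolist(), out.dtype)
```

The resulting dtype is float32 by default, or float64 when JAX's 64-bit mode is enabled.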