jax_to_numpy

muse.utils.jax_to_numpy(jax_array)

Convert a JAX array to a numpy.ndarray.

The array is first copied to host memory, so arrays stored on accelerator devices (such as GPUs or TPUs) are converted correctly.

Parameters:

jax_array (jax.Array) – JAX array to convert.

Returns:

The converted NumPy array.

Return type:

numpy.ndarray
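Example. The conversion described above can be approximated with public JAX and NumPy calls. This is a minimal sketch, not the actual source of muse.utils.jax_to_numpy. It assumes the function's behavior matches its description: copy to host, then return an ndarray. The helper name jax_to_numpy_sketch is hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def jax_to_numpy_sketch(jax_array: jax.Array) -> np.ndarray:
    """Hypothetical stand-in for muse.utils.jax_to_numpy (not its real code)."""
    # jax.device_get waits for any pending computation, then copies the data
    # from the device that holds it (CPU, GPU or TPU) into host memory.
    host_array = jax.device_get(jax_array)
    # np.asarray ensures the result is a numpy.ndarray without copying again.
    return np.asarray(host_array)


x = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)
y = jax_to_numpy_sketch(x)
print(type(y).__name__, y.shape, y.dtype)  # prints: ndarray (2, 3) float32
```

Calling jax.device_get explicitly makes the device-to-host transfer visible in the code. np.asarray alone would also trigger the transfer implicitly.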