jax_to_numpy
- muse.utils.jax_to_numpy(jax_array)
Convert a JAX array to a numpy.ndarray. The array is copied back to host memory first so arrays on accelerator devices convert cleanly.
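A minimal sketch of the conversion described above, assuming it amounts to a host copy followed by a NumPy view. The helper name jax_to_numpy_sketch is hypothetical and is not the muse implementation; only jax.device_get and numpy.asarray are real library calls.

```python
import jax
import jax.numpy as jnp
import numpy as np


def jax_to_numpy_sketch(jax_array):
    """Hypothetical stand-in for muse.utils.jax_to_numpy (assumed behavior)."""
    # Copy the array from its device (CPU, GPU, or TPU) back to host memory.
    host_array = jax.device_get(jax_array)
    # Ensure the result is a plain numpy.ndarray.
    return np.asarray(host_array)


# Usage: build a small JAX array and convert it.
x = jnp.arange(6.0).reshape(2, 3)
x_np = jax_to_numpy_sketch(x)
```

The result is an ordinary NumPy array with the same shape and values, so it can be passed to code that does not understand JAX arrays.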
- Parameters:
jax_array (jax.Array) – JAX array to convert.
- Returns:
The converted NumPy array.
- Return type: