numpy_to_jax

muse.utils.numpy_to_jax(numpy_array, cuda_device=None)

Convert a numpy.ndarray to a JAX array.

Floating-point precision is capped at float32: float64 input is downcast to float32, while float16 and integer dtypes are left unchanged.
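The dtype rule above can be sketched in plain NumPy. This is a minimal illustration, not muse's implementation; the helper name `cap_float_precision` is hypothetical.

```python
import numpy as np


def cap_float_precision(numpy_array: np.ndarray) -> np.ndarray:
    """Hypothetical helper mirroring the documented dtype rule."""
    # float64 is the only dtype that gets narrowed; it becomes float32.
    if numpy_array.dtype == np.float64:
        return numpy_array.astype(np.float32)
    # float16, float32 and integer arrays pass through untouched.
    return numpy_array


print(cap_float_precision(np.array([1.0, 2.0])).dtype)            # float32
print(cap_float_precision(np.array([1, 2], dtype=np.int32)).dtype)  # int32
```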

Parameters:
  • numpy_array (ndarray) – The array to convert.

  • cuda_device (int | None) – If provided, transfer the array to the CUDA device with this index. If None (the default), the array stays on the CPU, preserving the default of the previous tensor bridge.
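The conversion and device placement described above might look roughly like the following sketch, which uses `jax.device_put`. It is a hypothetical reimplementation, not muse's actual code: the name `to_jax_sketch` and the mapping of `cuda_device` to an index into `jax.devices("gpu")` are assumptions.

```python
import jax
import jax.numpy as jnp
import numpy as np


def to_jax_sketch(numpy_array: np.ndarray, cuda_device=None) -> jax.Array:
    """Hypothetical stand-in for numpy_to_jax; not the library's code."""
    # Apply the documented precision cap before handing data to JAX.
    if numpy_array.dtype == np.float64:
        numpy_array = numpy_array.astype(np.float32)
    if cuda_device is None:
        # Default: keep the array on the CPU.
        device = jax.devices("cpu")[0]
    else:
        # Assumption: cuda_device indexes the list of visible GPU devices.
        device = jax.devices("gpu")[cuda_device]
    return jax.device_put(numpy_array, device)


x = to_jax_sketch(np.ones((2, 3)))
print(x.shape, x.dtype)  # (2, 3) float32
```

Passing `cuda_device=0` would target the first GPU on a machine with a CUDA-enabled jaxlib; without a GPU backend, `jax.devices("gpu")` raises an error instead.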

Returns:

The converted JAX array.

Return type:

jax.Array