numpy_to_jax
- muse.utils.numpy_to_jax(numpy_array, cuda_device=None)
Convert a numpy.ndarray to a JAX array.

Floating-point precision is capped at float32: float64 input is downcast to float32, while narrower dtypes (float16, integers) are left as-is.
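
The dtype rule above can be illustrated with a minimal sketch. This is not the library's implementation: numpy_to_jax_sketch is a hypothetical stand-in written only to mirror the documented behavior, and it leaves out the cuda_device argument, whose handling is not described here.

```python
import numpy as np
import jax.numpy as jnp


def numpy_to_jax_sketch(numpy_array):
    """Hypothetical stand-in mirroring the documented dtype rule (not muse's code)."""
    # Cap floating-point precision: only float64 is downcast to float32.
    if numpy_array.dtype == np.float64:
        numpy_array = numpy_array.astype(np.float32)
    # Every other dtype is handed to JAX without an explicit cast.
    return jnp.asarray(numpy_array)


wide = numpy_to_jax_sketch(np.ones(3, dtype=np.float64))
half = numpy_to_jax_sketch(np.ones(3, dtype=np.float16))
ints = numpy_to_jax_sketch(np.arange(3, dtype=np.int32))

# Show the resulting dtype for each input.
print(wide.dtype, half.dtype, ints.dtype)
```

The explicit astype call makes the downcast visible in the code instead of relying on JAX's own default precision settings, which can also narrow 64-bit inputs on their own.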