I am trying to run a massive array (at least ~5 billion complex double-precision, i.e. complex128, elements) through a data processor in JAX on 1 NVIDIA GPU. If I try to load the dataset normally, I predictably run into a RESOURCE_EXHAUSTED error. How do I properly read a large array into multiple devices for a single process?
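For scale, here is a quick back-of-the-envelope check. It assumes exactly 5 billion elements stored as complex128 (two 64-bit floats each) and shows that the raw array alone needs about 80 GB, before any intermediate buffers:

```python
import numpy as np

n_elements = 5_000_000_000  # ~5 billion elements, as stated above
bytes_per_element = np.dtype(np.complex128).itemsize  # 16 bytes: real + imaginary float64
total_bytes = n_elements * bytes_per_element

# Decimal gigabytes and binary gibibytes
print(f"{total_bytes / 1e9:.0f} GB ({total_bytes / 2**30:.1f} GiB)")  # 80 GB (74.5 GiB)
```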
Currently, I am trying `jax.device_put(jnp.asarray(HDF5file), sharding)`, where `HDF5file` is the file and `sharding` is a sharding scheme (which I am fairly confident is correct, since it executes properly on smaller arrays), but the RESOURCE_EXHAUSTED error still shows up. Any help would be appreciated.
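One likely culprit is that `jnp.asarray(HDF5file)` materializes the whole array on the default device before `jax.device_put` ever applies the sharding. Below is a minimal sketch of an alternative. It assumes the data is exposed as an h5py `Dataset` (or any object supporting NumPy-style slicing, such as an `np.memmap`) and uses `jax.make_array_from_callback` so that each device reads only its own slice. `load_sharded` and the small demo array are hypothetical stand-ins. This approach only helps if the total array fits in the combined memory of the devices in the mesh; on a single GPU the whole array still has to fit on that one device.

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# complex128 requires 64-bit mode; otherwise JAX canonicalizes to complex64.
jax.config.update("jax_enable_x64", True)


def load_sharded(dataset, sharding):
    """Build a sharded jax.Array by reading only each device's slice.

    `dataset` is assumed to support NumPy-style slicing, e.g. an
    h5py Dataset obtained via h5py.File(path, "r")[name].
    """
    def read_shard(index):
        # `index` is a tuple of slices selecting this device's shard;
        # only that slice is read from `dataset` into host memory.
        return np.asarray(dataset[index])

    return jax.make_array_from_callback(dataset.shape, sharding, read_shard)


# Shard the leading axis across every visible device.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("x",))
sharding = NamedSharding(mesh, PartitionSpec("x"))

# Hypothetical small in-memory stand-in for the HDF5 dataset.
n = 8 * len(devices)  # leading axis divisible by the device count
host_data = np.arange(n) + 1j * np.arange(n)  # complex128 on the host
arr = load_sharded(host_data, sharding)
print(arr.dtype, arr.shape, len(arr.addressable_shards))
```

With a real file, the call would be `load_sharded(h5py.File(path, "r")[name], sharding)`, where `path` and `name` are placeholders. The file must stay open until the array has been built.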