Recommended way to read a large array into multiple GPUs from a single process
I am trying to run a massive array (at minimum ~5 billion complex double numbers, which is roughly 80 GB at 16 bytes per complex128 element) through a data processor in JAX, currently on a single NVIDIA GPU. When I load the dataset normally, I predictably get a RESOURCE_EXHAUSTED error. How do I properly read a large array into multiple devices from a single process?
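One possible approach, sketched under stated assumptions rather than offered as a definitive answer: keep the data on the host (here as an np.memmap, so it need not fit in RAM either), build a 1-D device Mesh over all visible GPUs, and use jax.make_array_from_callback with a NamedSharding so that each device only receives its own slice along axis 0. Assumptions: the dataset can be exposed as a NumPy array or memmap, the length of axis 0 is divisible by the device count, and the combined memory of all GPUs exceeds the array size. The file path, the raw ".c128" layout, and the helper names shard_rows_from_host and total_power are hypothetical. Note that JAX downcasts complex128 to complex64 unless 64-bit mode is enabled.

```python
import os
import tempfile

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# "complex double" (complex128) requires 64-bit mode; otherwise JAX uses complex64.
jax.config.update("jax_enable_x64", True)


def shard_rows_from_host(host_array):
    """Hypothetical helper: split host_array along axis 0 across all local devices.

    Each device receives only its own slice, so the full array never has to
    fit in a single GPU's memory.
    """
    mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
    sharding = NamedSharding(mesh, P("x"))
    # The callback is invoked per device with a tuple of slices describing
    # that device's shard; only that slice is read from the host and copied.
    return jax.make_array_from_callback(
        host_array.shape,
        sharding,
        lambda index: np.ascontiguousarray(host_array[index]),
    )


@jax.jit
def total_power(x):
    # Hypothetical processing step: elementwise work runs on each shard in
    # parallel, and XLA inserts the cross-device reduction for the sum.
    return jnp.sum(jnp.abs(x) ** 2)


if __name__ == "__main__":
    n_dev = len(jax.devices())
    n = 1024 * n_dev  # keep axis 0 divisible by the number of devices

    # Hypothetical on-disk file standing in for the real dataset.
    path = os.path.join(tempfile.mkdtemp(), "data.c128")
    mm = np.memmap(path, dtype=np.complex128, mode="w+", shape=(n,))
    mm[:] = np.arange(n) + 1j
    mm.flush()

    data = np.memmap(path, dtype=np.complex128, mode="r", shape=(n,))
    x = shard_rows_from_host(data)
    print(x.sharding, len(x.addressable_shards))
    print(total_power(x))
```

With ~80 GB of complex128 data this only helps if the processing itself can stay sharded: operations that gather the whole array onto one device (for example, an unsharded FFT over the full axis) would run into the same memory limit.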