I am trying to use 4 GPUs to train a physics-informed neural network (PINN). With a single GPU, utilization reaches 100% and training is fast (about 200 it/s). But when I use 4 GPUs with the sharding + jax.jit strategy shown in https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#way-batch-data-parallelism, the behavior changes.
I shard all my data across the GPUs and replicate the parameters and function state, following the auto-parallelization guidelines. With this setup, utilization on all 4 GPUs is quite low (around 10%) and training is slow (about 20 it/s).
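Concretely, my setup follows that notebook. This is a minimal sketch; the mesh axis name 'batch' and the names params, batch_x, batch_y, and update_fn are placeholders for my actual model and data:

import jax
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1D mesh over the 4 GPUs
devices = mesh_utils.create_device_mesh((4,))
mesh = Mesh(devices, axis_names=('batch',))

# Shard the training data along the batch dimension...
data_sharding = NamedSharding(mesh, P('batch'))
batch_x = jax.device_put(batch_x, data_sharding)
batch_y = jax.device_put(batch_y, data_sharding)

# ...and replicate the parameters and state on every device
replicated = NamedSharding(mesh, P())
params = jax.device_put(params, replicated)

# jax.jit then compiles one SPMD program that runs across all devices
train_step = jax.jit(update_fn)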
I am curious about the reason. I know there are alternative ways to parallelize my model, such as jax.experimental.shard_map or jax.pmap. I tried the jax.pmap approach described in https://medium.com/@save.our.thoughts/exploring-parallel-strategies-with-jax-b1adcb9ee0d6.
For example:
import functools

import jax
import jax.numpy as jnp

LEARNING_RATE = 1e-3

# `Params` is a pytree of model parameters and `loss_fn(params, x, y)`
# computes the PINN loss; both are defined elsewhere in my script.

# Remember that 'G' is just an arbitrary string label used to later tell
# 'jax.lax.pmean' which axis to reduce over. Here we call it 'G', but
# any name works, as long as 'pmean' uses the same one.
@functools.partial(jax.pmap, axis_name='G')
def update(params: Params, x: jnp.ndarray, y: jnp.ndarray):
    # Compute the gradients on the given minibatch (individually on each device)
    loss, grads = jax.value_and_grad(loss_fn)(params, x, y)
    # Combine the gradients across all devices (by taking their mean)
    grads = jax.lax.pmean(grads, axis_name='G')
    # Also combine the loss. Unnecessary for the update, but useful for logging
    loss = jax.lax.pmean(loss, axis_name='G')
    # Each device performs its own update, but since we start from the same
    # params and synchronise gradients, the params stay in sync
    new_params = jax.tree_util.tree_map(
        lambda param, g: param - g * LEARNING_RATE, params, grads)
    return new_params, loss
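I call update the way the article suggests: replicate the params once, then give each device its own slice of the batch. A rough sketch of my driver code, where params, x, and y stand for my real parameter pytree and full batch:

n_devices = jax.local_device_count()  # 4 in my case

# Replicate the parameters once onto every local device
replicated_params = jax.device_put_replicated(params, jax.local_devices())

# Split the batch so the leading axis indexes devices:
# (n_devices * per_device_batch, ...) -> (n_devices, per_device_batch, ...)
xs = x.reshape(n_devices, -1, *x.shape[1:])
ys = y.reshape(n_devices, -1, *y.shape[1:])

# pmap maps over the leading axis, so each GPU gets one slice
replicated_params, loss = update(replicated_params, xs, ys)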
The same problem shows up: GPU utilization and training speed are both low. Do you have any ideas about what might cause this?