I’m using JAX to evaluate a batched loss gradient that involves some complicated linear algebra (Cholesky decompositions and solves, among other things). The schematic form of my loss gradient is
jax.jit(jax.value_and_grad(lambda params, xs: jax.vmap(loss, in_axes=(None, 0))(params, xs).mean()))
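For concreteness, here is a stripped-down, self-contained version of the pattern. The real loss is more involved, but the shape is the same: a per-sample Cholesky factorization plus a solve. The sizes and the toy loss body are placeholders, not my actual code:

```python
import jax
import jax.numpy as jnp

dim = 32  # placeholder problem size

def loss(params, x):
    # Toy stand-in for the real loss: build an SPD matrix from params,
    # Cholesky-factor it, and solve against the data vector x.
    A = params @ params.T + dim * jnp.eye(dim)
    L = jnp.linalg.cholesky(A)
    y = jax.scipy.linalg.cho_solve((L, True), x)
    return 0.5 * x @ y + jnp.sum(jnp.log(jnp.diag(L)))

def mean_loss(params, xs):
    return jax.vmap(loss, in_axes=(None, 0))(params, xs).mean()

grad_fn = jax.jit(jax.value_and_grad(mean_loss))

key = jax.random.PRNGKey(0)
params = jax.random.normal(key, (dim, dim))
xs = jax.random.normal(key, (128, dim))   # nbatch = 128
value, grads = grad_fn(params, xs)        # first call triggers compilation
```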
I’m finding that the compilation/first-evaluation time is constant up to a certain batch size passed to vmap (as I would expect in general), and then starts growing superlinearly. On an A100, it’s roughly 6 minutes for nbatch <= 64, 13 minutes for nbatch = 128, and an hour for nbatch = 256, which becomes unwieldy.
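For reference, here is roughly how the compile time can be isolated from the actual execution using jax.jit's ahead-of-time lowering API (a sketch reusing `grad_fn`, `params`, and `key` from the snippet above):

```python
import time

for nbatch in (64, 128, 256):
    xs = jax.random.normal(key, (nbatch, dim))
    lowered = grad_fn.lower(params, xs)  # trace + lower: fast
    t0 = time.perf_counter()
    compiled = lowered.compile()         # XLA compilation: the slow part
    print(nbatch, time.perf_counter() - t0)
```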
What could be happening here? Would jax.vmap ever try to unroll the batch if it runs out of memory or compute units?
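One check I can think of is comparing the size of the lowered program across batch sizes: if vmap (or some later pass) were unrolling the batch, the HLO text should grow roughly in proportion to nbatch. Again a sketch, reusing `grad_fn` and the other names from above:

```python
for nbatch in (64, 128, 256):
    xs = jax.random.normal(key, (nbatch, dim))
    hlo = grad_fn.lower(params, xs).as_text()
    print(nbatch, len(hlo.splitlines()))  # line count as a crude size proxy
```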