Why does JAX compilation time grow with vmap batch size?
I’m using JAX to evaluate a batched loss gradient that involves some complicated linear algebra (Cholesky decompositions, the corresponding triangular solves, etc.). Schematically, the gradient of my loss is