Let’s say I have a function `f` that takes an integer argument and returns a fixed-size array. I want to evaluate the sum of `f(i)` over `range(N)` for some large `N`, such that storing all values in memory becomes problematic. With a simple `for` loop I can fix this easily and evaluate the sum with constant memory use:
```python
import jax.numpy as jnp

N = 10_000  # example size: large enough that stacking every f(i) is a problem
f = lambda i: i * jnp.identity(1000) + i  # a simple function that will quickly eat up memory

result = 0.  # generic initialization: works with any array-like function
for i in range(N):
    result += f(i)
```
But then the `for` loop is very slow. On the other hand, if I write this using JAX's `vmap`:
```python
from jax import vmap
result = jnp.sum(vmap(f)(jnp.arange(N)), axis=0)
```
I’m also in trouble, because `vmap` evaluates all `N` values and stacks them into a single array of shape `(N, 1000, 1000)` before the sum is taken, so memory use grows with `N` and I run out of it.
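To put a number on the memory problem, here is a small sketch that uses `jax.eval_shape` to inspect what `vmap(f)` over `N` indices would produce, without actually allocating it. `N = 10_000` is an assumed example size, and the byte count assumes JAX's default 32-bit types (x64 disabled).

```python
import math

import jax
import jax.numpy as jnp

N = 10_000  # assumed example size


def f(i):
    # The toy function from the question: a 1000x1000 array built from index i.
    return i * jnp.identity(1000) + i


# eval_shape traces vmap(f) abstractly: it reports the output shape and dtype
# without allocating the stacked array.
out = jax.eval_shape(jax.vmap(f), jax.ShapeDtypeStruct((N,), jnp.int32))
nbytes = math.prod(out.shape) * out.dtype.itemsize

print(out.shape, out.dtype, f"{nbytes / 1e9:.0f} GB")
```

Each `f(i)` is 1000 × 1000 float32 values (4 MB), so stacking 10,000 of them needs 10^4 × 4 × 10^6 bytes, i.e. about 40 GB, before `jnp.sum` ever runs.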
What would be the right way to vectorize this sum using JAX? I’ve looked for a while and couldn’t find an elegant solution.
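For reference, one pattern often used for this kind of reduction, sketched here under stated assumptions rather than as a definitive answer, is to split the index range into fixed-size chunks, `vmap` within each chunk, and accumulate across chunks with `jax.lax.scan`, so that only about one chunk of `f` values is live at a time. `chunked_sum` and `chunk_size` are hypothetical names, and the sketch assumes `N` is divisible by `chunk_size`.

```python
from functools import partial

import jax
import jax.numpy as jnp


def f(i):
    # The toy function from the question: a 1000x1000 array built from index i.
    return i * jnp.identity(1000) + i


@partial(jax.jit, static_argnums=(0, 1, 2))
def chunked_sum(fn, n, chunk_size):
    """Hypothetical helper: sum fn(i) for i in range(n), chunk_size indices at a time.

    Assumes n is divisible by chunk_size.
    """
    # Shape (n // chunk_size, chunk_size): scan walks over the rows.
    idx = jnp.arange(n).reshape(n // chunk_size, chunk_size)

    # Zero accumulator with the shape and dtype of a single fn(i).
    out = jax.eval_shape(fn, jax.ShapeDtypeStruct((), idx.dtype))
    init = jnp.zeros(out.shape, out.dtype)

    def body(acc, chunk):
        # Vectorize within the chunk, then reduce right away, so only
        # about chunk_size values of fn are held in memory at once.
        return acc + jnp.sum(jax.vmap(fn)(chunk), axis=0), None

    total, _ = jax.lax.scan(body, init, idx)
    return total


result = chunked_sum(f, 8, 4)  # sum of f(i) for i in range(8), 4 at a time
```

The chunk size is the trade-off knob: a larger `chunk_size` gives each step more parallel work but needs proportionally more memory, while `chunk_size=1` degenerates to a sequential (though compiled) loop.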