I have two arrays, say f and g: f is N by T by J dimensional and g is T by J dimensional. I’m trying to compute the following in JAX (for all 0 <= t < T):
$$s_t = \sum_{\theta=0}^{N-1} \sum_{a=0}^{J-1} f_{\theta,t,a}\, g_{t-a,\,a}$$
Note that if t - a < 0, I’d like the index t - a to default to 0 (that is, use g[0, a]).
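As a ground truth for checking faster versions, here is a minimal, deliberately slow sketch that evaluates the sum literally: for each t it adds f[θ, t, a] * g[max(t - a, 0), a] over all θ and a. It assumes the clamped-index reading of "default to 0" (the same one the `jnp.maximum(..., 0)` call in the code below uses); `s_reference` is a hypothetical name.

```python
import jax.numpy as jnp

def s_reference(f, g):
    """Literal triple loop over the sum; only meant for small correctness checks.

    Assumes f has shape (N, T, J) and g has shape (T, J), and that a negative
    row index t - a is clamped to 0.
    """
    N, T, J = f.shape
    out = []
    for t in range(T):
        total = 0.0
        for theta in range(N):
            for a in range(J):
                # Clamp the shifted time index to 0 when t - a is negative.
                total += f[theta, t, a] * g[max(t - a, 0), a]
        out.append(total)
    return jnp.array(out)
```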
What would be the fastest approach?
Right now I build a list of all possible index triples, multiply the two arrays elementwise at those indices, and sum the result:
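<code>import jax.numpy as jnp

# Assumes f (shape (N, T, J)), g (shape (T, J)) and the ints N, T, J are already defined.
# Enumerate every (θ, t, a) triple in row-major order; this Python loop runs N*T*J times.
all_indices = jnp.array([(θ, t, a) for θ in range(N) for t in range(T) for a in range(J)])
θ_idx, t_idx, a_idx = all_indices[:, 0], all_indices[:, 1], all_indices[:, 2]
# Clamp negative time offsets t - a to 0.
tma_idx = jnp.maximum(t_idx - a_idx, 0)
# Gather f[θ, t, a] * g[max(t - a, 0), a] for every triple.
unrolled = f[θ_idx, t_idx, a_idx] * g[tma_idx, a_idx]
# Restore the (N, T, J) layout and sum over θ and a, leaving shape (T,).
s = unrolled.reshape(N, T, J).sum(axis=(0,2))
</code>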
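One possible vectorized sketch, assuming the same clamped-index semantics as the code above: since g does not depend on θ, f can be summed over its first axis once, and the shifted array g[max(t - a, 0), a] can be built as a single (T, J) array by broadcasting `jnp.arange` indices and gathering with `jnp.take_along_axis`. This avoids building the N·T·J index list in Python; `s_vectorized` is a hypothetical name, and no timing claims are made here.

```python
import jax
import jax.numpy as jnp

@jax.jit
def s_vectorized(f, g):
    # g does not depend on θ, so reduce f over its first axis once: shape (T, J).
    f_sum = f.sum(axis=0)
    T, J = g.shape
    t = jnp.arange(T)[:, None]   # shape (T, 1)
    a = jnp.arange(J)[None, :]   # shape (1, J)
    # Row index into g for every (t, a), clamped to 0 as in jnp.maximum(t - a, 0).
    rows = jnp.maximum(t - a, 0)  # shape (T, J) by broadcasting
    # g_shifted[t, a] = g[rows[t, a], a]
    g_shifted = jnp.take_along_axis(g, rows, axis=0)
    # Multiply elementwise and sum over a, leaving shape (T,).
    return (f_sum * g_shifted).sum(axis=1)
```

Here the only Python-level work is shape arithmetic during tracing; the reduction of f costs O(N·T·J) and the gather and product cost O(T·J).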
This does not seem particularly efficient or elegant, and I would appreciate a better solution.
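If "default to 0" is instead meant to make the whole term vanish when t - a < 0 (rather than reading g[0, a]), a minimal variant under that assumption gathers with a clamped, always-valid index and then zeroes the invalid positions with `jnp.where`. `s_masked` is a hypothetical name; which reading is intended is an assumption.

```python
import jax
import jax.numpy as jnp

@jax.jit
def s_masked(f, g):
    # Sum f over θ first: shape (T, J).
    f_sum = f.sum(axis=0)
    T, J = g.shape
    t = jnp.arange(T)[:, None]
    a = jnp.arange(J)[None, :]
    shifted = t - a  # shape (T, J); negative where t < a
    # Gather with a clamped index so every lookup is in bounds...
    g_shifted = jnp.take_along_axis(g, jnp.maximum(shifted, 0), axis=0)
    # ...then drop the terms whose true index t - a is negative.
    g_shifted = jnp.where(shifted >= 0, g_shifted, 0.0)
    return (f_sum * g_shifted).sum(axis=1)
```

Clamping before masking keeps the gather in bounds, which matters because JAX does not raise on out-of-range indices.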