Given a matrix m with shape (n, n), I need to compute the series of “offset traces” [np.trace(m, offset=i) for i in range(q)] in JAX. For my application, n >> q, and q is a static parameter.
The obvious JAX approach using vmap over the offset does not work. Is this possibly because, although the trace has a fixed output size, each offset diagonal has a different length?
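To make the shape concern concrete, here is a small check on a toy 4x4 matrix (the matrix and the variable names are just for illustration). It only shows that the diagonals have different lengths when the offset is a plain Python integer; whether this is really the reason the vmap version fails is my guess, not something I have confirmed.

```python
import jax.numpy as jnp

m = jnp.arange(16.0).reshape(4, 4)

# With a Python-int offset, the k-th superdiagonal of an (n, n) matrix has n - k entries.
lengths = [jnp.diagonal(m, offset=k).shape[0] for k in range(3)]
print(lengths)  # [4, 3, 2]
```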
I came up with two other approaches using JAX which work but are more than 100x slower than NumPy (see the timings below). get_traces_jax_1 is the more efficient of the two, but it does a lot of extra work when I only need a few diagonals, and I don’t think that extra work gets compiled away.
Is there a more efficient way to do this in JAX with similar performance to NumPy? I want to use JAX because:
- I need to vmap this across many matrices;
- It is part of a larger algorithm, other parts of which are significantly sped up by JAX jit.
Below are the methods I explored and timings on my computer.
import numpy as np
from numpy import random
import jax
jax.config.update("jax_enable_x64", True)  # default is float32
from jax import numpy as jnp
from functools import partial

n, q = 1000, 5

# check the methods produce the same result
def distance(u, v):
    return jnp.max(jnp.abs(u - v))

# numpy - this is what I want
def get_traces_np(mat, q):
    return np.array([np.trace(mat, offset=i) for i in range(q)])

# jax
# !! This does not work
@partial(jax.jit, static_argnums=(1,))
def get_traces_jax_broken(mat, q):
    return jax.vmap(lambda i: jnp.trace(mat, offset=i))(jnp.arange(q))  # !! does not work

@partial(jax.jit, static_argnums=(1,))
def get_traces_jax_0(mat, q):
    return jnp.array([jnp.trace(mat, offset=i) for i in range(q)])

@partial(jax.jit, static_argnums=(1,))
def get_traces_jax_1(mat, q):
    n = mat.shape[0]
    padded = jnp.pad(mat, ((0, 0), (0, n-1)), 'constant')
    shifts = jax.vmap(lambda v, i: jnp.roll(v, -i))(padded, jnp.arange(n))[:, :n]
    return jnp.sum(shifts, axis=0)[:q]

mat = random.uniform(size=(n, n))

# Check they produce the same result and precompile
d0 = distance(get_traces_np(mat, q), get_traces_jax_0(mat, q))
d1 = distance(get_traces_np(mat, q), get_traces_jax_1(mat, q))
print(f'Errors: {d0}, {d1}')

mat = jnp.array(mat)
print('Numpy:')
%timeit get_traces_np(mat, q)  # 7.43 microseconds
print('Jax 0:')
%timeit get_traces_jax_0(mat, q)  # 4.82ms
print('Jax 1:')
%timeit get_traces_jax_1(mat, q)  # 1.22ms
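One direction that might avoid the extra work is to gather only the entries on the first q superdiagonals, which touches about n*q elements rather than all n*n. The following is a minimal sketch of that idea, not a benchmarked solution: get_traces_gather is a hypothetical name, and I have not measured how it compares with the timings above. It relies on indexing with mode="fill" so that out-of-range column indices contribute zero.

```python
from functools import partial

import jax
jax.config.update("jax_enable_x64", True)  # match the float64 setup used above
import numpy as np
from jax import numpy as jnp


@partial(jax.jit, static_argnums=(1,))
def get_traces_gather(mat, q):
    """Hypothetical helper: sum the first q superdiagonals with one gather."""
    n = mat.shape[0]
    rows = jnp.arange(n)[None, :]          # shape (1, n)
    cols = rows + jnp.arange(q)[:, None]   # shape (q, n); entry [k, r] is r + k
    # Entries with r + k >= n are out of bounds and are filled with 0.
    vals = mat.at[rows, cols].get(mode="fill", fill_value=0)
    return vals.sum(axis=1)                # one sum per offset, shape (q,)


n, q = 1000, 5
mat = np.random.default_rng(0).uniform(size=(n, n))
expected = np.array([np.trace(mat, offset=i) for i in range(q)])
result = get_traces_gather(jnp.asarray(mat), q)

# Batched use: vmap over a leading axis of matrices, keeping q static.
batch = jnp.stack([jnp.asarray(mat), 2 * jnp.asarray(mat)])
batched = jax.vmap(lambda m: get_traces_gather(m, q))(batch)
```

Because q stays a static Python integer and the index arrays have fixed shape (q, n), every offset produces a row of the same length, which is what lets the function be wrapped in vmap over a batch of matrices.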