I am unsure about the best way to vectorize objects in JAX (Python).
In particular, I want to write code that handles calling a method both on a single instance of a class and on multiple (vectorized) instances of the class.
Below is a simple example of what I would like to achieve.
import jax
import jax.numpy as jnp
import jax.random as random

class Dummy:
    def __init__(self, x, key):
        self.x = x
        self.key = key

    def to_pytree(self):
        return (self.x, self.key), None

    def get_noisy_x(self):
        self.key, subkey = random.split(self.key)
        return self.x + random.normal(subkey, self.x.shape)

    @staticmethod
    def from_pytree(auxiliary, pytree):
        return Dummy(*pytree)

jax.tree_util.register_pytree_node(Dummy,
                                   Dummy.to_pytree,
                                   Dummy.from_pytree)
The class Dummy contains some information, x and key, and has a method, get_noisy_x. The following code works as expected:
key = random.PRNGKey(0)
dummy = Dummy(jnp.array([1., 2., 3.]), key)
dummy.get_noisy_x()
I would like get_noisy_x to also work on a vectorized version of the Dummy object.
key = random.PRNGKey(0)
key, subkey = random.split(key)
key_batch = random.split(subkey, 100)
dummy_vmap = jax.vmap(lambda x: Dummy(jnp.array([1., 2., 3.]), x))(key_batch)
I would expect dummy_vmap to be an array of Dummy objects; however, dummy_vmap instead turns out to be a single Dummy with vectorized x and key. This is not ideal for me because it changes the behavior of the code. For example, if I call dummy_vmap.get_noisy_x(), I get an error saying that self.key, subkey = random.split(self.key) fails because self.key is not a single key. While this error could be solved in several ways (and in this example, vectorization is not really needed), my goal is to understand how to write code in an object-oriented way that correctly handles both
dummy = Dummy(jnp.array([1., 2., 3.]), key)
dummy.get_noisy_x()
and
vectorized_dummy = .... ?
vectorized_dummy.get_noisy_x()
Notice that the example I have given could work in several ways without involving vectorization. What I am looking for, however, is a more generic way to deal with vectorization in much more complicated scenarios.
Update
I have found out that I need to vectorize get_noisy_x as well.
dummy_vmap = jax.vmap(lambda x: Dummy(jnp.array([1., 2., 3.]), x))(key_batch)
jax.vmap(lambda self: Dummy.get_noisy_x(self))(dummy_vmap) # this function call works exactly as expected.
However, this solution seems a bit counter-intuitive, and not really scalable, as in a larger project I would need to vectorize all functions of interest.
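For what it's worth, one pattern I have been experimenting with is a decorator that dispatches to jax.vmap automatically whenever the instance holds batched keys, so individual call sites don't need to vmap each method themselves. This is only a sketch: it assumes new-style keys (jax.random.key), where a single key is 0-dimensional and a batch of keys has a leading axis, and auto_batch is a hypothetical helper, not a JAX API.

```python
import functools

import jax
import jax.numpy as jnp
import jax.random as random

def auto_batch(method):
    # Hypothetical helper (not a JAX API): vmap a method over `self`
    # whenever the instance holds a batch of keys.
    @functools.wraps(method)
    def wrapper(self):
        # New-style keys (random.key) are 0-d; a leading axis means "batched".
        if jnp.ndim(self.key) > 0:
            return jax.vmap(method)(self)
        return method(self)
    return wrapper

class Dummy:
    def __init__(self, x, key):
        self.x = x
        self.key = key

    def to_pytree(self):
        return (self.x, self.key), None

    @staticmethod
    def from_pytree(auxiliary, pytree):
        return Dummy(*pytree)

    @auto_batch
    def get_noisy_x(self):
        # Pure: derive a subkey without mutating self.
        subkey = random.split(self.key, 1)[0]
        return self.x + random.normal(subkey, self.x.shape)

jax.tree_util.register_pytree_node(Dummy, Dummy.to_pytree, Dummy.from_pytree)

x = jnp.array([1., 2., 3.])
out_single = Dummy(x, random.key(0)).get_noisy_x()
print(out_single.shape)  # (3,)

key_batch = random.split(random.key(1), 100)
dummy_vmap = jax.vmap(Dummy, in_axes=(None, 0))(x, key_batch)
out_batch = dummy_vmap.get_noisy_x()
print(out_batch.shape)  # (100, 3)
```

This relies on Dummy being registered as a pytree so that vmap can map over `self` directly; whether the 0-d-key heuristic is robust enough for a larger project is exactly part of my question.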
I would expect dummy_vmap to be an array of Dummy objects; however, instead, dummy_vmap results to be only one Dummy with vectorized x and key.
Your expectation here is out of line with how JAX approaches vectorization: JAX uses a struct-of-arrays pattern rather than an array-of-structs pattern. This should work entirely seamlessly with your existing object, so long as you never explicitly construct a vectorized object; for example, you could do something like this:
def apply_dummy(x, key):
    return Dummy(x, key).get_noisy_x()

key = random.key(0)
key, subkey = random.split(key)
key_batch = random.split(subkey, 100)

x = jnp.array([1., 2., 3.])
out_single = apply_dummy(x, key)
print(out_single.shape)  # (3,)

out_batch = jax.vmap(apply_dummy, in_axes=(None, 0))(x, key_batch)
print(out_batch.shape)  # (100, 3)
If you want to construct a vectorized dummy object, you can do so by applying vmap
to its constructor:
vectorized_dummy = jax.vmap(Dummy, in_axes=(None, 0))(x, key_batch)
However, as you found, this will not work correctly with your Dummy object as it’s currently defined, because its methods are not batch-aware. The general approach here would be to modify get_noisy_x so that it does the appropriate thing when self.key and self.x are batched. The details will depend on the assumptions you want to make: for example, if both key and x have a batch dimension, do you vectorize over both simultaneously, or do you return the outer product? The answer, and therefore the implementation, will depend on information not provided in your question.
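For concreteness, here is a sketch under one possible set of assumptions: x and key share a common leading batch axis and are vectorized over simultaneously, and keys are new-style (jax.random.key), so that an unbatched key is 0-dimensional. This is one choice among several, not the only reasonable one:

```python
import jax
import jax.numpy as jnp
import jax.random as random

class Dummy:
    def __init__(self, x, key):
        self.x = x
        self.key = key

    def get_noisy_x(self):
        # Unbatched case: new-style keys (random.key) are 0-d arrays.
        if jnp.ndim(self.key) == 0:
            subkey = random.split(self.key, 1)[0]
            return self.x + random.normal(subkey, self.x.shape)
        # Batched case (assumed convention): x and key share a leading
        # batch axis, and we vectorize over both simultaneously.
        def single(x, key):
            subkey = random.split(key, 1)[0]
            return x + random.normal(subkey, x.shape)
        return jax.vmap(single)(self.x, self.key)

x = jnp.array([1., 2., 3.])
out_single = Dummy(x, random.key(0)).get_noisy_x()
print(out_single.shape)  # (3,)

key_batch = random.split(random.key(1), 100)
batched = Dummy(jnp.broadcast_to(x, (100, 3)), key_batch)
out_batch = batched.get_noisy_x()
print(out_batch.shape)  # (100, 3)
```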
Also, as a side note, the way this method is defined will generally be problematic in JAX:
def get_noisy_x(self):
    self.key, subkey = random.split(self.key)
    return self.x + random.normal(subkey, self.x.shape)
The issue is that it is impure: calling the function mutates self in place (changing the value of self.key). Functions with side effects like this may not behave as you expect when used with JAX transformations like jit, vmap, or grad. For a discussion of these issues, see JAX Sharp Bits: Pure Functions.
As a demonstration of this, take a look at the value of dummy_vmap.key before and after running the code under your Update:
dummy_vmap = jax.vmap(lambda x: Dummy(jnp.array([1., 2., 3.]), x))(key_batch)
print(dummy_vmap.key)
jax.vmap(Dummy.get_noisy_x)(dummy_vmap)
print(dummy_vmap.key) # unchanged!
The fix would be to not rely on this kind of side-effect in your code.
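For example, a pure variant of get_noisy_x could return the noisy value together with a new Dummy carrying the advanced key, leaving the caller to thread the state explicitly. A minimal sketch of that pattern:

```python
import jax.numpy as jnp
import jax.random as random

class Dummy:
    def __init__(self, x, key):
        self.x = x
        self.key = key

    def get_noisy_x(self):
        # Pure: instead of mutating self.key, return the result together
        # with a new instance carrying the advanced key.
        new_key, subkey = random.split(self.key)
        noisy = self.x + random.normal(subkey, self.x.shape)
        return noisy, Dummy(self.x, new_key)

dummy = Dummy(jnp.array([1., 2., 3.]), random.key(0))
noisy, dummy = dummy.get_noisy_x()   # thread the updated state explicitly
noisy2, dummy = dummy.get_noisy_x()  # key was advanced, so fresh noise
```

Because nothing is mutated, this version composes cleanly with jit and vmap; the cost is that callers must rebind the returned instance themselves.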