I’m trying to implement the Hutchinson-Skilling estimator (later used in the context of diffusion models) in JAX, but I’m having difficulty understanding why, even in very basic cases, I need so many Gaussian samples to get an accurate estimate.
I tested it on a minimal example: a two-dimensional function that squares each argument elementwise. However, the error is quite erratic, and I need around 10k noise samples to get it below 1% in most cases (it depends on the RNG key). What’s wrong?
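For reference, the identity I’m relying on is the standard Hutchinson trace estimator applied to the Jacobian: for f : R^d -> R^d,

div f(x) = tr(J_f(x)) = E_{v ~ N(0, I)}[ v^T J_f(x) v ],

which I approximate by averaging v^T J_f(x) v over num_samples independent Gaussian draws of v, computing each Jacobian-vector product J_f(x) v with jax.jvp.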
import jax
import jax.numpy as jnp
from jax import random
# Define the test function f(x) = x**2, applied elementwise
def f(x):
    return x**2
# Define the Hutchinson-Skilling estimator for divergence using Gaussian vectors
def get_div_fn(f, step_rng, num_samples=1000):
    key = step_rng

    def hutchinson_skilling_divergence(x):
        def single_sample_divergence(x, key):
            # Draw a Gaussian probe v and form v . (J_f(x) v) via a JVP
            v = jax.random.normal(key, shape=x.shape)
            jacobian_vector_product = jax.jvp(f, (x,), (v,))[1]
            return jnp.dot(v, jacobian_vector_product)

        # One independent probe per sample, averaged into the final estimate
        keys = jax.random.split(key, num_samples)
        divergence_estimates = jax.vmap(single_sample_divergence, in_axes=(None, 0))(x, keys)
        return jnp.mean(divergence_estimates)

    return hutchinson_skilling_divergence
input = jnp.array([1.0, 1.0])
# Random key for JAX
key = random.PRNGKey(42)
f_div = get_div_fn(f, key)
result = f_div(input)
print(result)
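For comparison, the exact value I measure the error against is the trace of the full Jacobian (cheap to compute here since the input is only 2-dimensional); it should be 2*1.0 + 2*1.0 = 4.0 at this input:

# Exact divergence for reference: trace of the full Jacobian
exact = jnp.trace(jax.jacfwd(f)(input))
print(exact)  # 4.0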