I am following the Gaussian Process example in Numpyro Website : https://num.pyro.ai/en/latest/examples/gp.html
When I run MCMC I get the following Error TypeError: 'NUTS' object is not callable
I havent included the get_data
function as they are in the link above. But X,Y are (30,)
in shape and of type jaxlib.xla_extension.ArrayImpl
and the kernel function returns a array of shape (30,30)
and of type jaxlib.xla_extension.ArrayImpl
.
Any clue what the error is ?
# Kernal Function
def kernel(X,Z,var,length, noise,jitter=1.0e-6, include_noise= True):
deltaXsq = jnp.power((X[:,None] - Z) /length, 2.0)
k = var*jnp.exp(-0.5 * deltaXsq)
if include_noise:
k += (noise + jitter) *np.eye(X.shape[0])
return k
# Model
def model(X,Y):
# uninformative log-normal hyperpriors
var = numpyro.sample("kernel_var", dist.LogNormal(0.0,10.0))
noise = numpyro.sample("noise", dist.LogNormal(0.0,10.0))
length = numpyro.sample("length", dist.LogNormal(0.0,10.0))
# compute kernel
k = kernel(X,X,var,length, noise)
# sample Y according to the standard gaissian process formulae
numpyro.sample(
"Y",
dist.MultivariateNormal(loc = jnp.zeros(X.shape[0]), covariance_matrix=k),
obs = Y
)
if __name__ == "__main__":
X,Y,X_test = get_data() #(30,),(30,),(400,)
#k = kernel(X = X,Z = X,var = VAR,length = LENGTH,noise = NOISE)
NUM_WARMUP = 1000
NUM_CHAINS = 1
THINNING = 2
NUM_DATA = 25
DEVICE = "cpu"
NUM_SAMPLES = 10
INIT_STRATEGY = init_to_sample()
KEY = random.PRNGKey(0)
k0,k1 = random.split(KEY)
kernel = NUTS(model, init_strategy=INIT_STRATEGY)
mcmc = MCMC(
kernel,
num_warmup=NUM_WARMUP,
num_samples = NUM_SAMPLES,
num_chains = NUM_CHAINS,
thinning = THINNING,
progress_bar = True
)
mcmc.run(k1, X,Y)
Full Error Message
Traceback (most recent call last):
File "/Users/imantha/workspace/python/tutorials/jax/gp.py", line 82, in <module>
mcmc.run(k1, X.reshape(-1,1),Y.reshape(-1,1))
File "/Users/imantha/tools/miniforge3/envs/ml3.12/lib/python3.12/site-packages/numpyro/infer/mcmc.py", line 682, in run
states_flat, last_state = partial_map_fn(map_args)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/imantha/tools/miniforge3/envs/ml3.12/lib/python3.12/site-packages/numpyro/infer/mcmc.py", line 443, in _single_chain_mcmc
new_init_state = self.sampler.init(
^^^^^^^^^^^^^^^^^^
File "/Users/imantha/tools/miniforge3/envs/ml3.12/lib/python3.12/site-packages/numpyro/infer/hmc.py", line 749, in init
init_params = self._init_state(
^^^^^^^^^^^^^^^^^
File "/Users/imantha/tools/miniforge3/envs/ml3.12/lib/python3.12/site-packages/numpyro/infer/hmc.py", line 693, in _init_state
) = initialize_model(
^^^^^^^^^^^^^^^^^
File "/Users/imantha/tools/miniforge3/envs/ml3.12/lib/python3.12/site-packages/numpyro/infer/util.py", line 662, in initialize_model
) = _get_model_transforms(substituted_model, model_args, model_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/imantha/tools/miniforge3/envs/ml3.12/lib/python3.12/site-packages/numpyro/infer/util.py", line 456, in _get_model_transforms
model_trace = trace(model).get_trace(*model_args, **model_kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/imantha/tools/miniforge3/envs/ml3.12/lib/python3.12/site-packages/numpyro/handlers.py", line 171, in get_trace
self(*args, **kwargs)
File "/Users/imantha/tools/miniforge3/envs/ml3.12/lib/python3.12/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/imantha/tools/miniforge3/envs/ml3.12/lib/python3.12/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/imantha/tools/miniforge3/envs/ml3.12/lib/python3.12/site-packages/numpyro/primitives.py", line 105, in __call__
return self.fn(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/imantha/workspace/python/tutorials/jax/gp.py", line 47, in model
k = kernel(X,X,var,length, noise)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: 'NUTS' object is not callable
1