I’m using NumPyro to sample several variables in a model, one of which is the number of iterations of a for loop. Here is an analogous toy model:
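```python
import numpyro
import numpyro.distributions as dist

def model():
    mu = 0.
    sigma = numpyro.sample("sigma", dist.HalfNormal())
    # number of loop iterations, itself a random variable
    T = numpyro.sample("T", dist.DiscreteUniform(1, 5))
    for i in range(1, T):
        mu += 1.
    # normal_logl and data are defined elsewhere (not shown)
    logl = numpyro.deterministic("logl", normal_logl(data, mu, sigma))
    numpyro.factor("log_likelihood", logl)
```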
The above model raises an error. Replacing the for loop with jax.lax.fori_loop doesn’t help either. Is there a workaround?
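Below is a minimal sketch of one possible workaround. It uses plain JAX rather than NumPyro and assumes the error comes from `T` being a traced array inside a jitted function, which Python's `range()` cannot consume. `MAX_T` and the three helper functions are hypothetical names for illustration only. The idea is to loop over every possible iteration up to the known upper bound of `T` and mask out the iterations the original loop would skip. That way the loop length no longer depends on a sampled value.

```python
import jax
import jax.numpy as jnp

# Assumed: matches the upper bound of DiscreteUniform(1, 5) in the toy model.
MAX_T = 5


def mu_python_loop(T):
    """Original pattern: range() needs a concrete int, so this fails under jax.jit."""
    mu = 0.0
    for i in range(1, T):
        mu += 1.0
    return mu


def mu_masked(T):
    """Fixed-length loop over all possible iterations, masking those with i >= T."""
    i = jnp.arange(1, MAX_T)                 # static iteration indices 1 .. MAX_T - 1
    T = jnp.asarray(T)[..., None]            # trailing axis so an array of T values broadcasts
    active = i < T                           # True where the original loop would run
    increments = jnp.where(active, 1.0, 0.0) # the toy body adds 1.0 per active iteration
    return jnp.sum(increments, axis=-1)


def mu_fori(T):
    """lax.fori_loop with a traced scalar upper bound (lowered to a while_loop)."""
    return jax.lax.fori_loop(1, T, lambda i, mu: mu + 1.0, jnp.float32(0.0))


print(jax.jit(mu_masked)(jnp.int32(3)))               # traced scalar T
print(jax.jit(mu_masked)(jnp.array([1, 2, 3, 4, 5])))  # several T values at once
print(jax.jit(mu_fori)(jnp.int32(3)))
```

For this toy body the loop collapses to `mu = T - 1.`, so masking only matters when the body has no closed form. The `fori_loop` variant runs for a scalar traced `T`, but it needs scalar bounds and so cannot take an array of `T` values. JAX also does not support reverse-mode differentiation through the `while_loop` it lowers to, which matters if the body depends on a continuous parameter such as `sigma`. The masked version avoids both limitations, at the cost of always doing `MAX_T - 1` iterations of work.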