I am making a model for supervised poisson matrix factorization, the plate model is as follows:
My final goal is to get the posterior of theta, beta, xi, ni, nu, and gamma using SVI in pyro. So I have to write a guide function, My observed variables are X, X’ (X_aux in the code), and Y.
Now I have made the code for the model and I can generate the data with expected dimensions for variables.
But I have problems with the guide function.
This is my model function:
def model(x_data, x_aux_data, y_data):
n = 200
p = 50
k = 3
d = 20
p_aux = 10
sigma = torch.tensor(1, dtype=torch.float) # Standard deviation for the Normal distributions
mu = torch.tensor(0, dtype=torch.float) # mean for the Normal distributions
a = torch.tensor(2, dtype=torch.float)
a_prime = torch.tensor(2, dtype=torch.float)
b_prime = torch.tensor(1, dtype=torch.float)
c = torch.tensor(2, dtype=torch.float)
c_prime = torch.tensor(2, dtype=torch.float)
d_prime = torch.tensor(1, dtype=torch.float)
features_axis = pyro.plate("features_axis", p_aux, dim=-5)
module_axis = pyro.plate("module_axis", k, dim=-4)
sample_axis = pyro.plate("sample_axis", n, dim=-3)
gene_axis = pyro.plate("gene_axis", p, dim=-2)
latent_factors_axis = pyro.plate("latent_factors_axis", d, dim=-1)
with gene_axis:
ni = pyro.sample('ni', dist.Gamma(c_prime, d_prime))
with latent_factors_axis, gene_axis:
beta = pyro.sample('beta', dist.Gamma(c, ni))
with sample_axis:
xi = pyro.sample('xi', dist.Gamma(a_prime, b_prime))
with latent_factors_axis, sample_axis:
theta = pyro.sample('theta', dist.Gamma(a, xi))
rate = torch.matmul(theta, beta.T)
x_squeezed = rate.squeeze(1)
x = pyro.sample('x', dist.Poisson(x_squeezed), obs=x_data)
with latent_factors_axis, module_axis:
nu = pyro.sample('nu', dist.Normal(mu, sigma))
with features_axis, module_axis:
gamma = pyro.sample('gamma', dist.Normal(mu, sigma))
with sample_axis, features_axis:
x_aux = pyro.sample('x_aux', dist.Normal(mu, sigma), obs=x_aux_data)
# x_aux = pyro.sample('x_aux', dist.Normal(mu, sigma))
bern_term1 = torch.matmul(theta.squeeze(1), nu.squeeze(1, 2).T)
bern_term2 = torch.matmul(x_aux.squeeze(1, 3, 4).T, gamma.squeeze(2, 3, 4))
ben_arg = bern_term1 + bern_term2
module_probs = torch.sigmoid(ben_arg)
y = pyro.sample('y', dist.Bernoulli(probs=module_probs), obs=y_data)
# y = pyro.sample('y', dist.Bernoulli(probs=module_probs))
return x, x_aux, y
and this is my guide function:
def guide(x_data, x_aux_data, y_data):
n = 200
p = 50
k = 3
d = 20
p_aux = 10
# Define variational parameters for ni
c_prime_q = pyro.param('c_prime_q', torch.ones(p, dtype=torch.float), constraint=dist.constraints.positive)
d_prime_q = pyro.param('d_prime_q', torch.ones(p, dtype=torch.float), constraint=dist.constraints.positive)
ni = pyro.sample('ni', dist.Gamma(c_prime_q, d_prime_q))
# Define variational parameters for beta
c_q = pyro.param('c_q', torch.ones(d, p, dtype=torch.float), constraint=dist.constraints.positive)
beta = pyro.sample('beta', dist.Gamma(c_q, ni))
# Define variational parameters for xi
a_prime_q = pyro.param('a_prime_q', torch.ones(n, dtype=torch.float), constraint=dist.constraints.positive)
b_prime_q = pyro.param('b_prime_q', torch.ones(n, dtype=torch.float), constraint=dist.constraints.positive)
xi = pyro.sample('xi', dist.Gamma(a_prime_q, b_prime_q))
# Define variational parameters for theta
a_q = pyro.param('a_q', torch.ones(d, n, dtype=torch.float), constraint=dist.constraints.positive)
theta = pyro.sample('theta', dist.Gamma(a_q, xi))
# Define variational parameters for nu
nu_mu = pyro.param('nu_mu', torch.zeros(d, k, dtype=torch.float))
nu_scale = pyro.param('nu_scale', torch.ones(d, k, dtype=torch.float), constraint=dist.constraints.positive)
nu = pyro.sample('nu', dist.Normal(nu_mu, nu_scale))
# Define variational parameters for gamma
gamma_mu = pyro.param('gamma_mu', torch.zeros(p_aux, k, dtype=torch.float))
gamma_scale = pyro.param('gamma_scale', torch.ones(p_aux, k, dtype=torch.float),
constraint=dist.constraints.positive)
gamma = pyro.sample('gamma', dist.Normal(gamma_mu, gamma_scale))
I generate data using a the same model, and I use the following code to use SVI:
optimizer = pyro.optim.Adam({"lr": 0.01})
# Setup the inference algorithm
svi = pyro.infer.SVI(model=model,
guide=guide,
optim=optimizer,
loss=pyro.infer.Trace_ELBO())
# Number of iterations
num_iterations = 1000
# Run inference
for i in range(num_iterations):
loss = svi.step(x_data, x_aux_data, y_data)
if i % 100 == 0:
print(f"Iteration {i} : Loss {loss}")
But I get the following error:
ValueError: Shape mismatch inside plate(‘latent_factors_axis’) at site beta dim -1, 20 vs 50
(pyro version: 1.9.0, python: 3.10)
I suspect the guide model is incomplete, because in the model I manipulate the dimensions of matrices so that I can multiply them, but I am not sure how I should reflect them in the guide function.
I would appreciate if anybody know how can I tackle this problem.
I tried to remove the supervised part, but I still have problems.