Problem
We are implementing a VAE from scratch but are allowed to use the torch.Tensor
data structure for GPU capabilities. This is a school assignment and we seem to be in over our heads. We’ve looked for guides and tutorials on VAEs, but every single one uses PyTorch/TensorFlow, which obfuscates the back-propagation process completely. We are looking for help in getting the back-propagation process started, because it seems like once you get it started, it follows naturally backwards through your network.
Architecture
Encoder – there is a split after step 4 (the first LeakyRelu): steps 5 – 7 below are done separately for latent_mean and for latent_log (the log-variance). A rough code sketch of the full forward pass follows the decoder list.
- Input (419, 419)
- Flatten
- Dense(419 * 419, 256)
- LeakyRelu
- Dense(256, 100)
- LeakyRelu
- Normalize
Decoder
- Dense(100, 256)
- LeakyRelu
- Normalize
- Dense(256, 419 * 419)
- Normalize
- Reshape
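To make the wiring concrete, here is roughly how we picture the whole forward pass with raw torch.Tensors. The weight names (W1, W_mean, W_log, W3, W4, …) are just illustrative, and the Normalize steps are left as placeholder comments because we are not sure yet how to handle them:

import torch

def leaky_relu(t, slope=0.01):
    # element-wise LeakyReLU implemented by hand
    return torch.where(t > 0, t, slope * t)

# illustrative encoder weights
W1, b1 = torch.randn(419 * 419, 256) * 0.01, torch.zeros(256)
W_mean, b_mean = torch.randn(256, 100) * 0.01, torch.zeros(100)
W_log, b_log = torch.randn(256, 100) * 0.01, torch.zeros(100)

# illustrative decoder weights
W3, b3 = torch.randn(100, 256) * 0.01, torch.zeros(256)
W4, b4 = torch.randn(256, 419 * 419) * 0.01, torch.zeros(419 * 419)

def encoder_forward(x):
    h = x.reshape(-1)                              # Flatten (419, 419) -> (175561,)
    h = leaky_relu(h @ W1 + b1)                    # Dense(419 * 419, 256) + LeakyRelu
    latent_mean = leaky_relu(h @ W_mean + b_mean)  # mean branch: Dense(256, 100) + LeakyRelu
    latent_log = leaky_relu(h @ W_log + b_log)     # log-variance branch: Dense(256, 100) + LeakyRelu
    # Normalize step omitted (placeholder)
    return latent_mean, latent_log

def decoder_forward(z):
    h = leaky_relu(z @ W3 + b3)                    # Dense(100, 256) + LeakyRelu
    # Normalize step omitted (placeholder)
    out = h @ W4 + b4                              # Dense(256, 419 * 419)
    # Normalize step omitted (placeholder)
    return out.reshape(419, 419)                   # Reshape back to (419, 419)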
Flow
# Our current implementation does a forward/backward pass for each sample
# However, weights are updated after each batch
latent_mean, latent_log = encoder.forward(x)
latent_vector = reparameterize(latent_mean, latent_log)
reconstructed = decoder.forward(latent_vector)
reconstruction_loss = torch.mean((x - reconstructed) ** 2)
kl_divergence = -0.5 * torch.mean(1 + latent_log - latent_mean ** 2 - torch.exp(latent_log))
loss = reconstruction_loss + kl_divergence
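For context, reparameterize is meant to be the usual sampling trick (a sketch, assuming latent_log holds the log-variance):

def reparameterize(latent_mean, latent_log):
    # z = mean + sigma * eps, with sigma = exp(0.5 * log_var) and eps ~ N(0, I)
    eps = torch.randn_like(latent_mean)
    return latent_mean + torch.exp(0.5 * latent_log) * eps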
This loss is where we are stuck. From class, it seems we are supposed to calculate the derivative of the loss function with respect to the weights of the decoder’s output neurons. Our class notes also suggest that the derivation for the output layer is, in general, different from the derivation for hidden layers, yet a lot of the neural-network examples online don’t seem to make this distinction. I know we are using a well-known loss function, but I am unsure how to differentiate it or even what it is called (we found it in a random blog post and have seen it in several others since).
I’ve seen other examples online begin back-propagation by multiplying the loss/error by the derivative of the activation function evaluated at the output activation, i.e.
delta = error * output_activation_function_derivative(output_activation)
If we can figure out how to calculate the initial delta, I think we can follow the network backwards and update the rest of the weights with little issue.
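For concreteness, here is what we think those initial gradients might look like for our loss (we are not confident this is right; it assumes the mean reductions shown under Flow and treats the decoder’s final Normalize as an identity):

n_pixels = x.numel()            # 419 * 419 values in the reconstruction
n_latent = latent_mean.numel()  # 100 latent dimensions

# MSE term: derivative of mean((x - reconstructed) ** 2) w.r.t. reconstructed
d_recon = 2.0 * (reconstructed - x) / n_pixels

# If the decoder's last step applied an activation f to a pre-activation a,
# the initial delta would be d_recon * f_prime(a); with no activation it is
# just d_recon, flattened to line up with the last Dense layer's output.
delta_output = d_recon.reshape(-1)

# The KL term never touches the decoder; its gradients enter the encoder
# directly at latent_mean and latent_log, added to whatever flows back
# through reparameterize from the reconstruction path.
d_kl_d_mean = latent_mean / n_latent
d_kl_d_log = 0.5 * (torch.exp(latent_log) - 1.0) / n_latent

If that is right, we believe the gradient for the decoder’s last Dense weights is just the outer product of that layer’s input activation with delta_output, and the remaining layers should follow backwards one at a time.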