I have a custom layer in my VAE which handles the loss calculation. I have a custom callback which is meant to update the beta value at the beginning of every epoch. The calculation is correct because it prints out the desired beta value at the beginning of every epoch. However, the real beta value in the CustomVariationalLayer is never updated.
@register_keras_serializable('CustomVariationalLayer')
class CustomVariationalLayer(keras.layers.Layer):
def __init__(self, beta=1.0, **kwargs):
self.is_placeholder = True
super(CustomVariationalLayer, self).__init__(**kwargs)
self.beta = beta
self.recon_loss_metric = tf.keras.metrics.Mean(name='recon_loss')
self.kl_loss_metric = tf.keras.metrics.Mean(name='kl_loss')
def vae_loss(self, x, z_decoded, z_mean, z_log_var):
recon_loss = keras.losses.binary_crossentropy(K.flatten(x), K.flatten(z_decoded))
kl_loss = -0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
print(f"nReal beta value {self.beta:.4f}")
return recon_loss, self.beta * kl_loss
def call(self, inputs):
x = inputs[0]
z_decoded = inputs[1]
z_mean = inputs[2]
z_log_var = inputs[3]
recon_loss, kl_loss = self.vae_loss(x, z_decoded, z_mean, z_log_var)
self.add_loss(K.mean(recon_loss + kl_loss))
self.recon_loss_metric.update_state(recon_loss)
self.kl_loss_metric.update_state(kl_loss)
return x
def compute_output_shape(self, input_shape):
return input_shape[0]
def get_metrics(self):
return {'recon_loss': self.recon_loss_metric.result().numpy(),
'kl_loss': self.kl_loss_metric.result().numpy()}
class BetaAnnealing(keras.callbacks.Callback):
def __init__(self, layer, initial_beta=0.0, final_beta=1.0, epochs=100):
super(BetaAnnealing, self).__init__()
self.layer = layer
self.initial_beta = initial_beta
self.final_beta = final_beta
self.epochs = epochs
def on_epoch_begin(self, epoch, logs=None):
new_beta = self.initial_beta + (self.final_beta - self.initial_beta) * (epoch / (self.epochs - 1))
self.layer.beta = new_beta
print(f"Epoch {epoch+1}: Beta value updated to {new_beta:.4f}")
print(f"Layer beta value: {self.layer.beta}")
This is my training
batch_size = 128
epochs = 50
# Early stopping
es = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=5)
# Beta annealing
ba = BetaAnnealing(CustomVariationalLayer, initial_beta=0.0, final_beta=1.0, epochs=epochs)
# Compile the model
model.compile(optimizer='adam',
loss=None)
# Train the model
history = model.fit(x_train, x_train,
batch_size=batch_size,
epochs=epochs,
shuffle=True,
validation_data=(x_validate, x_validate),
callbacks=[es, ba])
It looked like it was working at first because the print statements from BetaAnnealing would be correct, but I added a print statement to the custom layer to verify, and the beta value was ‘1.0’ every time. this is a sample from my training output:
Epoch 1: Beta value updated to 0.0000
Layer beta value: 0.0
Epoch 1/50
Real beta value 1.0000
Real beta value 1.0000
53/53 ━━━━━━━━━━━━━━━━━━━━ 0s 112ms/step - kl_loss: 237.6640 - loss: 238.3435 - recon_loss: 0.6795
Real beta value 1.0000
53/53 ━━━━━━━━━━━━━━━━━━━━ 17s 162ms/step - kl_loss: 234.6529 - loss: 235.3318 - recon_loss: 0.6789 - val_kl_loss: 5.1546 - val_loss: 5.8057 - val_recon_loss: 0.6506
Epoch 2: Beta value updated to 0.0204
Layer beta value: 0.02040816326530612
Epoch 2/50
53/53 ━━━━━━━━━━━━━━━━━━━━ 2s 40ms/step - kl_loss: 7.7230 - loss: 8.3554 - recon_loss: 0.6325 - val_kl_loss: 922390720.0000 - val_loss: 922390720.0000 - val_recon_loss: 7.9802
Epoch 3: Beta value updated to 0.0408
Layer beta value: 0.04081632653061224
Epoch 3/50
53/53 ━━━━━━━━━━━━━━━━━━━━ 2s 39ms/step - kl_loss: 4.8922 - loss: 5.5163 - recon_loss: 0.6242 - val_kl_loss: 740397.3125 - val_loss: 740402.4375 - val_recon_loss: 4.9765
Epoch 4: Beta value updated to 0.0612
Layer beta value: 0.061224489795918366
Epoch 4/50
53/53 ━━━━━━━━━━━━━━━━━━━━ 2s 39ms/step - kl_loss: 1.3511 - loss: 1.9775 - recon_loss: 0.6265 - val_kl_loss: 1217.8662 - val_loss: 1218.7296 - val_recon_loss: 0.8520
Epoch 5: Beta value updated to 0.0816
Layer beta value: 0.08163265306122448
Epoch 5/50
53/53 ━━━━━━━━━━━━━━━━━━━━ 2s 41ms/step - kl_loss: 0.2503 - loss: 0.8771 - recon_loss: 0.6268 - val_kl_loss: 3.0466 - val_loss: 3.6738 - val_recon_loss: 0.6260
How would I get this to work?