I am building an autoencoder implemented as a custom tf.keras.Model. While the model after training performs well, I haven’t been able to save it and reload it properly. I have tried both model.save() method and save_weights() but in both case the model fails completely to perform its task.
This autoencoder is calling two others tf.keras.Model, the encoder and the decoder which in turn calls custom layers.
A residual convolution block:
@tf.keras.utils.register_keras_serializable(package="ae", name="ResidualConvBlock")
class ResidualConvBlock(tf.keras.Layer):
def __init__(self, n_filters: int, activation = 'relu', is_res = False, **kwargs) -> None:
super().__init__(**kwargs)
self.is_res = is_res
self.conv1 = tf.keras.layers.Conv2D(filters = n_filters, kernel_size = 3,
strides=1, kernel_initializer = 'he_normal', padding = 'same')
self.norm1 = tf.keras.layers.BatchNormalization()
self.activation1 = tf.keras.layers.Activation(activation)
self.conv2 = tf.keras.layers.Conv2D(filters = n_filters, kernel_size = 3,
strides=1, kernel_initializer = 'he_normal', padding = 'same')
self.norm2 = tf.keras.layers.BatchNormalization()
self.activation2 = tf.keras.layers.Activation(activation)
self.shortcut = tf.keras.layers.Conv2D(n_filters, kernel_size=1, strides=1, padding='valid')
def call(self, inputs, training=False):
# First convolutional layer
x1 = self.conv1(inputs)
x1 = self.norm1(x1)
x1 = self.activation1(x1)
# Second convolutional layer
x2 = self.conv2(x1)
x2 = self.norm2(x2)
out = self.activation2(x2)
if self.is_res:
if inputs.shape[-1] == out.shape[-1]:
out = inputs + out
else:
out = self.shortcut(inputs) + out
out = out / 1.414
return out
An encoder block:
@tf.keras.utils.register_keras_serializable(package="ae", name="EncoderBlock")
class EncoderBlock(tf.keras.Layer):
def __init__(self, n_filters=64, pool_size=(2,2), dropout=0.3, **kwargs):
super().__init__(**kwargs)
self.c = ResidualConvBlock(n_filters=n_filters)
self.p = tf.keras.layers.MaxPooling2D(pool_size=pool_size)
self.d = tf.keras.layers.Dropout(0.3)
def call(self, inputs):
c = self.c(inputs)
p = self.p(c)
d = self.d(p)
return d, c
The encoder model:
@tf.keras.utils.register_keras_serializable(package="ae", name="Encoder")
class Encoder(tf.keras.Model):
def __init__(self, latent_dim:int, n_filters: int, depth: int, **kwargs):
super().__init__(**kwargs)
self.n_filters = n_filters
self.depth = depth
self.enc_blocks = []
self.bottle_neck = tf.keras.layers.Dense(units = latent_dim)
for i in range(self.depth):
if i == 0:
self.enc_blocks.append(EncoderBlock(n_filters=self.n_filters, pool_size=(2,3)))
else:
self.enc_blocks.append(EncoderBlock(n_filters=2 ** i * self.n_filters))
def call(self,inputs):
convs = []
x = inputs
for block in self.enc_blocks:
x, c = block(x)
convs.append(c)
out = self.bottle_neck(x)
return out, convs
def build_graph(self, raw_shape):
x = tf.keras.Input(shape=raw_shape)
return tf.keras.Model(inputs=[x], outputs=self.call(x))
def get_config(self):
base_config = super().get_config()
config = {
"n_filters": self.n_filters,
"depth": self.depth,
"EncoderBlock": tf.keras.legacy.saving.serialize_keras_object(self.enc_blocks[0])
}
return {**base_config, **config}
@classmethod
def from_config(cls, config):
EncoderBlock_config = config.pop("EncoderBlock")
EncoderBlock = tf.keras.legacy.saving.deserialize_keras_object(EncoderBlock_config)
return cls(EncoderBlock, **config)
# return cls(**config)
A decoder block:
@tf.keras.utils.register_keras_serializable(package="ae", name="DecoderBlock")
class DecoderBlock(tf.keras.Layer):
def __init__(self, n_filters=64, kernel_size=3, strides=(2,2), dropout=0.3, is_res = False, **kwargs):
super().__init__(**kwargs)
self.is_res = is_res
self.u = tf.keras.layers.Conv2DTranspose(n_filters, kernel_size, strides = strides, padding = 'same')
self.d = tf.keras.layers.Dropout(dropout)
self.c = ResidualConvBlock(n_filters=n_filters)
self.is_res = is_res
def call(self, inputs, conv):
u = self.u(inputs)
if self.is_res:
x = tf.keras.layers.concatenate([u, conv])
else:
x = u
x = self.d(x)
out = self.c(x)
return out
The decoder model:
@tf.keras.utils.register_keras_serializable(package="ae", name="Decoder")
class Decoder(tf.keras.Model):
def __init__(self, n_filters:int, depth:int = 4, output_channels:int =3, **kwargs):
super().__init__(**kwargs)
self.n_filters = n_filters
self.depth = depth
self.output_channels = output_channels
self.decoder_blocks = []
for i in range(depth):
if i == depth -1:
self.decoder_blocks.append(DecoderBlock(n_filters=2 ** (depth - i -1) * self.n_filters, strides = (2,3)))
else:
self.decoder_blocks.append(DecoderBlock(n_filters=2 ** (depth - i -1) * self.n_filters))
self.final_conv = tf.keras.layers.Conv2D(self.output_channels, (1, 1), activation='sigmoid')
def call(self, inputs, convs):
out = inputs
for i in range(self.depth):
out = self.decoder_blocks[i](out, convs[-i-1])
outputs = self.final_conv(out)
return outputs
def build_graph(self, raw_shape):
x = tf.keras.Input(shape=raw_shape)
y = []
for i in range(self.depth-1):
y.append(tf.keras.Input(shape=(raw_shape[0] * 2 ** (i+1), raw_shape[1] * 2 ** (i+1), int(self.n_filters * 2 ** (self.depth-i-1)))))
y.append(tf.keras.Input(shape=(raw_shape[1] * 2 ** (self.depth), raw_shape[0] * 2 ** (self.depth-1) * 3, int(self.n_filters))))
y.reverse()
return tf.keras.Model(inputs=[x], outputs=self.call(x, y))
def get_config(self):
base_config = super().get_config()
config = {
"n_filters": self.n_filters,
"depth": self.depth,
"output_channels": self.output_channels,
"DecoderBlock": tf.keras.legacy.saving.serialize_keras_object(self.decoder_blocks[0])
}
return {**base_config, **config}
@classmethod
def from_config(cls, config):
DecoderBlock_config = config.pop("DecoderBlock")
DecoderBlock = tf.keras.legacy.saving.deserialize_keras_object(DecoderBlock_config)
return cls(**config)
And finally the auto-encoder model:
@tf.keras.utils.register_keras_serializable(package="ae", name="AutoEncoder")
class AE_model(tf.keras.Model):
def __init__(self, n_filters: int, latent_dim: int, depth:int, **kwargs):
super().__init__(**kwargs)
self.latent_dim = latent_dim
self.depth = depth
# encoder
self.encoder = Encoder(n_filters=n_filters, latent_dim=latent_dim, depth=self.depth) #encoder(latent_dim, n_filters)
# decoder
self.decoder = Decoder(n_filters=n_filters, depth=self.depth, output_channels=3)
def call (self, inputs):
encoded, convs = self.encoder(inputs)
decoded = self.decoder(encoded, convs)
return decoded
def build_graph(self, raw_shape):
x = tf.keras.Input(shape=raw_shape)
return tf.keras.Model(inputs=[x], outputs=self.call(x))
def get_config(self):
base_config = super().get_config()
config = {
"latent_dim": self.latent_dim,
"depth": self.depth,
"n_filters": self.encoder.n_filters,
"encoder": tf.keras.legacy.saving.serialize_keras_object(self.encoder),
"decoder": tf.keras.legacy.saving.serialize_keras_object(self.decoder)
}
return {**base_config, **config}
@classmethod
def from_config(cls, config):
encoder_config = config.pop("encoder")
encoder = tf.keras.legacy.saving.deserialize_keras_object(encoder_config)
decoder_config = config.pop("decoder")
decoder = tf.keras.legacy.saving.deserialize_keras_object(decoder_config)
#return cls(encoder, decoder, **config)
return cls(**config)
To be able to save the Autoencoder and reload it without error I had to overwrite the get_config and from_config methods of the Encoder, Decoder and AE_model classes.
However I don’t understand why in the case of the Encoder I had to return the config with the Encoder class deserialized, otherwise it would complain that the Encoder was unknown, while for the Decoder of AE_model, it complains that element such as n_filters are defined multiple times.
With the above configure I am able to save a model and reload it. However the reconstructed image from the autoencoder are completely grey.
Result model after training:
enter image description here
Result model after training, save and reload:
enter image description here
Question:
How can I save the model or its weights and be able to reload it for future inference or additional training?