I’m encountering an error while using a custom BatchNormalization layer in TensorFlow for an object detection project. The error message I’m getting is:
```
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: Exception encountered when calling BatchNormalization.call().
Using a symbolic `tf.Tensor` as a Python `bool` is not allowed.
```
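The message itself can be reproduced without any Keras layer. This is a minimal sketch assuming TensorFlow 2.x; `python_if` and `graph_version` are hypothetical names. A plain Python `if` works on an eager tensor, but inside a `tf.function` traced with AutoGraph disabled (used here as a stand-in for library code whose `if` statements are not rewritten into `tf.cond`) the same `if` receives a symbolic tensor and raises `OperatorNotAllowedInGraphError`:

```python
import tensorflow as tf


def python_if(flag):
    # An ordinary Python conditional: it needs a concrete truth value for `flag`.
    return 1 if flag else 0


@tf.function(autograph=False)  # AutoGraph off: the `if` is not turned into tf.cond
def graph_version(flag):
    return python_if(tf.logical_and(flag, True))


# Eager mode: the tensor has a concrete value, so the conditional succeeds.
print(python_if(tf.logical_and(tf.constant(True), True)))

# Graph mode: the tensor is symbolic while tracing, so the conditional raises.
try:
    graph_version(tf.constant(True))
except tf.errors.OperatorNotAllowedInGraphError as err:
    print(type(err).__name__)
```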
The error seems to originate from the `call` method of my `BatchNormalization` class. Here's the relevant part of the custom layer:
```python
import tensorflow as tf

class BatchNormalization(tf.keras.layers.BatchNormalization):
    """
    Make trainable=False freeze BN for real (the og version is sad)
    """
    def call(self, x, training=False):
        if training is None:
            training = False
        training = tf.logical_and(training, self.trainable)
        return super().call(x, training)
```
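A likely cause, assuming the parent `call()` branches on `training` with an ordinary Python `if` (as recent Keras versions of `BatchNormalization` appear to): `tf.logical_and` turns `training` into a `tf.Tensor`, and that tensor is symbolic while `model.fit()` traces the training step. Below is a minimal sketch of a variant that keeps `training` a plain Python bool instead. The class name `FrozenBatchNormalization` is hypothetical, and the sketch assumes Keras passes `training` as a Python bool or `None`, which is what `fit()` and `predict()` normally do. Note also that in TF 2.x Keras, setting `trainable = False` on a stock `BatchNormalization` layer is documented to make it run in inference mode, so the subclass may not be needed at all.

```python
import tensorflow as tf


class FrozenBatchNormalization(tf.keras.layers.BatchNormalization):
    """Hypothetical BatchNormalization that uses moving statistics when trainable=False.

    `training` is kept a plain Python bool, so any Python `if` inside the
    parent call() never receives a symbolic tf.Tensor.
    """

    def call(self, inputs, training=None):
        if training is None:  # no training phase was specified
            training = False
        # Python `and` on Python bools: no graph op is created here.
        training = bool(training) and self.trainable
        return super().call(inputs, training=training)
```

Because `training` stays a Python bool, the choice between batch statistics and moving statistics is made at trace time rather than inside the graph; each distinct value simply produces its own trace.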
I've also tried handling the `None` value of the `training` argument in my layer with `tf.cond`, like so:
```python
training = tf.cond(tf.equal(training, None), lambda: tf.constant(False), lambda: training)
```
However, I’m still receiving the same error. Can anyone help me understand why I’m encountering this error and how to resolve it?
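On the `tf.cond` attempt: `tf.cond` always adds graph ops, so inside a traced function its result is again a symbolic `tf.Tensor`, and a Python `if` in the parent layer would fail on it in the same way. Whether `training` is `None` is a Python-level question that `training is None` already answers without creating any TF op. A small check, assuming TensorFlow 2.x; `cond_in_graph` and `results` are hypothetical names:

```python
import tensorflow as tf

results = {}


@tf.function
def cond_in_graph(flag):
    # Same shape as the attempted fix: a tf.cond whose branches both yield booleans.
    out = tf.cond(flag, lambda: tf.constant(False), lambda: flag)
    # Python side effects run while the function is being traced.
    results["is_tensor"] = isinstance(out, tf.Tensor)
    results["is_python_bool"] = isinstance(out, bool)
    return out


cond_in_graph(tf.constant(True))
print(results)
```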