I encountered an error while training a GAN model. The error occurs in the training loop shown below.
Code snippet:
np.random.seed(SEED)
for epoch in range(10):
    for batch in tqdm(range(STEPS_PER_EPOCH)):
        # Generate fake images
        noise = np.random.normal(0, 1, size=(BATCH_SIZE, NOISE_DIM))
        fake_X = generator.predict(noise)

        # Select a random batch of real images
        idx = np.random.randint(0, X_train.shape[0], size=BATCH_SIZE)
        real_X = X_train[idx]

        # Reshape fake images to match the shape of real images
        fake_X_reshaped = np.transpose(fake_X, axes=(0, 2, 1, 3))

        # Concatenate real and fake images into a single batch
        X = np.concatenate((real_X, fake_X_reshaped))

        # Create labels for the discriminator
        disc_y = np.zeros(2 * BATCH_SIZE)
        disc_y[:BATCH_SIZE] = 1

        # Train the discriminator on the batch
        d_loss = discriminator.train_on_batch(X, disc_y)

        # Generate new noise for the generator
        noise = np.random.normal(0, 1, size=(BATCH_SIZE, NOISE_DIM))

        # Create labels for the generator (trick the discriminator)
        y_gen = np.ones(BATCH_SIZE)

        # Train the generator to fool the discriminator
        g_loss = gan.train_on_batch(noise, y_gen)

    print(f"EPOCH: {epoch + 1} Generator Loss: {g_loss:.4f} Discriminator Loss: {d_loss:.4f}")
    noise = np.random.normal(0, 1, size=(10, NOISE_DIM))
    sample_images(noise, (2, 5))
Error message:
0%| | 0/3750 [00:00<?, ?it/s]
WARNING:tensorflow:6 out of the last 16 calls to <function Model.make_predict_function.<locals>.predict_function at 0x0000023093B2E480> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has reduce_retracing=True option that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
1/1 [==============================] - 1s 587ms/step
0%| | 0/3750 [00:02<?, ?it/s]
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
Cell In[138], line 33
30 y_gen = np.ones(BATCH_SIZE)
32 # Train the generator to fool the discriminator
---> 33 g_loss = gan.train_on_batch(noise, y_gen)
35 print(f"EPOCH: {epoch + 1} Generator Loss: {g_loss:.4f} Discriminator Loss: {d_loss:.4f}")
36 noise = np.random.normal(0, 1, size=(10, NOISE_DIM))
File ~\AppData\Roaming\Python\Python311\site-packages\keras\src\engine\training.py:2787, in Model.train_on_batch(self, x, y, sample_weight, class_weight, reset_metrics, return_dict)
2783 iterator = data_adapter.single_batch_iterator(
2784 self.distribute_strategy, x, y, sample_weight, class_weight
2785 )
2786 self.train_function = self.make_train_function()
-> 2787 logs = self.train_function(iterator)
2789 logs = tf_utils.sync_to_numpy_or_python_type(logs)
2790 if return_dict:
File ~\AppData\Roaming\Python\Python311\site-packages\tensorflow\python\util\traceback_utils.py:153, in filter_traceback.<locals>.error_handler(*args, **kwargs)
151 except Exception as e:
152 filtered_tb = _process_traceback_frames(e.__traceback__)
--> 153 raise e.with_traceback(filtered_tb) from None
154 finally:
155 del filtered_tb
File C:\Users\MOHAME~1\AppData\Local\Temp\__autograph_generated_file_1d1u69b.py:15, in outer_factory.<locals>.inner_factory.<locals>.tf__train_function(iterator)
13 try:
14 do_return = True
---> 15 retval_ = ag__.converted_call(ag__.ld(step_function), (ag__.ld(self), ag__.ld(iterator)), None, fscope)
16 except:
17 do_return = False
File ~\AppData\Roaming\Python\Python311\site-packages\keras\src\engine\training.py:1384, in Model.make_train_function.<locals>.step_function(model, iterator)
1380 run_step = tf.function(
1381 run_step, jit_compile=True, reduce_retracing=True
1382 )
1383 data = next(iterator)
-> 1384 outputs = model.distribute_strategy.run(run_step, args=(data,))
1385 outputs = reduce_per_replica(
1386 outputs,
1387 self.distribute_strategy,
1388 reduction=self.distribute_reduction_method,
1389 )
1390 return outputs
File ~\AppData\Roaming\Python\Python311\site-packages\keras\src\engine\training.py:1373, in Model.make_train_function.<locals>.step_function.<locals>.run_step(data)
1372 def run_step(data):
-> 1373 outputs = model.train_step(data)
1374 # Ensure counter is updated only if `train_step` succeeds.
1375 with tf.control_dependencies(_minimum_control_deps(outputs)):
File ~\AppData\Roaming\Python\Python311\site-packages\keras\src\engine\training.py:1150, in Model.train_step(self, data)
1148 # Run forward pass.
1149 with tf.GradientTape() as tape:
-> 1150 y_pred = self(x, training=True)
1151 loss = self.compute_loss(x, y, y_pred, sample_weight)
1152 self._validate_target_and_loss(y, loss)
File ~\AppData\Roaming\Python\Python311\site-packages\keras\src\utils\traceback_utils.py:70, in filter_traceback.<locals>.error_handler(*args, **kwargs)
67 filtered_tb = _process_traceback_frames(e.__traceback__)
68 # To get the full stack trace, call:
69 # `tf.debugging.disable_traceback_filtering()`
---> 70 raise e.with_traceback(filtered_tb) from None
71 finally:
72 del filtered_tb
File ~\AppData\Roaming\Python\Python311\site-packages\keras\src\engine\input_spec.py:298, in assert_input_compatibility(input_spec, inputs, layer_name)
296 if spec_dim is not None and dim is not None:
297 if spec_dim != dim:
--> 298 raise ValueError(
299 f'Input {input_index} of layer "{layer_name}" is '
300 "incompatible with the layer: "
301 f"expected shape={spec.shape}, "
302 f"found shape={display_shape(x.shape)}"
303 )
ValueError: in user code:
File "C:UsersMohamed WalidAppDataRoamingPythonPython311site-packageskerassrcenginetraining.py", line 1401, in train_function *
return step_function(self, iterator)
File "C:UsersMohamed WalidAppDataRoamingPythonPython311site-packageskerassrcenginetraining.py", line 1384, in step_function **
outputs = model.distribute_strategy.run(run_step, args=(data,))
File "C:UsersMohamed WalidAppDataRoamingPythonPython311site-packageskerassrcenginetraining.py", line 1373, in run_step **
outputs = model.train_step(data)
File "C:UsersMohamed WalidAppDataRoamingPythonPython311site-packageskerassrcenginetraining.py", line 1150, in train_step
y_pred = self(x, training=True)
File "C:UsersMohamed WalidAppDataRoamingPythonPython311site-packageskerassrcutilstraceback_utils.py", line 70, in error_handler
raise e.with_traceback(filtered_tb) from None
File "C:UsersMohamed WalidAppDataRoamingPythonPython311site-packageskerassrcengineinput_spec.py", line 298, in assert_input_compatibility
raise ValueError(
ValueError: Exception encountered when calling layer 'gan_model' (type Functional).
Input 0 of layer "discriminator" is incompatible with the layer: expected shape=(None, 208, 176, 1), found shape=(4, 176, 208, 1)
Call arguments received by layer 'gan_model' (type Functional):
• inputs=tf.Tensor(shape=(4, 100), dtype=float32)
• training=True
• mask=None
I’m attempting to train a Generative Adversarial Network (GAN). The goal is to generate fake images with the generator and then train the discriminator to distinguish between real and fake images. Here’s a brief overview of the process (a simplified sketch of how the models are stacked follows the list):
1. Generating fake images using the generator model.
2. Selecting a random batch of real images from the dataset.
3. Concatenating real and fake images into a single batch.
4. Creating labels for the discriminator.
5. Training the discriminator on the batch.
6. Generating new noise for the generator.
7. Creating labels for the generator to fool the discriminator.
8. Training the generator.
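For context, here is a simplified sketch of how such a stacked GAN is typically wired (placeholder layers, not my actual architectures; NOISE_DIM = 100 and the discriminator input shape (208, 176, 1) are taken from the error message above):

import numpy as np
from tensorflow import keras
from tensorflow.keras import layers

NOISE_DIM = 100
IMG_SHAPE = (208, 176, 1)  # shape the discriminator expects

# Placeholder generator: noise vector -> image tensor
generator = keras.Sequential([
    layers.Input(shape=(NOISE_DIM,)),
    layers.Dense(int(np.prod(IMG_SHAPE)), activation="tanh"),
    layers.Reshape(IMG_SHAPE),
], name="generator")

# Placeholder discriminator: image tensor -> real/fake probability
discriminator = keras.Sequential([
    layers.Input(shape=IMG_SHAPE),
    layers.Flatten(),
    layers.Dense(1, activation="sigmoid"),
], name="discriminator")
discriminator.compile(optimizer="adam", loss="binary_crossentropy")

# Combined model: noise -> generator -> (frozen) discriminator
discriminator.trainable = False
noise_in = keras.Input(shape=(NOISE_DIM,))
gan = keras.Model(noise_in, discriminator(generator(noise_in)), name="gan_model")
gan.compile(optimizer="adam", loss="binary_crossentropy")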
Additional Information:
The code is implemented in Python using Keras.
I’m using a specific architecture for the generator and discriminator models.
I’ve checked the shapes of the input data and they seem to be correct (a shape check follows this list).
The error occurs during the training loop, specifically when calling gan.train_on_batch(noise, y_gen).
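For reference, this is roughly how I am comparing the shapes (a simple sanity-check sketch; generator, discriminator, gan, X_train, and NOISE_DIM are the same objects used in the training loop above):

import numpy as np

# Model-side shapes
print("generator output:   ", generator.output_shape)     # the error suggests (None, 176, 208, 1)
print("discriminator input:", discriminator.input_shape)   # (None, 208, 176, 1) per the error
print("gan (stacked) input:", gan.input_shape)

# Data-side shapes
print("real image batch:   ", X_train[:4].shape)
noise = np.random.normal(0, 1, size=(4, NOISE_DIM))
print("fake image batch:   ", generator.predict(noise, verbose=0).shape)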
Any insights into what might be causing this error would be greatly appreciated. Thank you!