At full production scale this appears to be the root cause of a GPU OOM error that happens during the first epoch. See the fully working code and logs below: why does Epoch 1/3 go through the dataset twice instead of behaving like epochs 2 and 3?
Also, is there a way to verify that batches are getting distributed/split correctly across all devices after calling model.fit? Looping through dist_dataset I'm able to see a PerReplica distribution which looks correct (see the sketch below), but I'd like to verify the behaviour inside model.fit, because I get different results if I don't use experimental_distribute_dataset, which according to the TensorFlow and Keras docs is not strictly needed.
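For reference, this is roughly how I loop through dist_dataset to see the PerReplica split (a minimal sketch using strategy.experimental_local_results to unpack each replica's slice; strategy and dist_dataset come from the full code below):
# Each element of dist_dataset is a tuple of PerReplica values,
# with one component per replica/device
for features, labels in dist_dataset:
    # Unpack the PerReplica container into one tensor per device
    for i, replica_features in enumerate(strategy.experimental_local_results(features)):
        print(f"replica {i}: {replica_features.numpy()}")
    break  # one global batch is enough for a sanity check
Fully working code: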
import tensorflow as tf
# Set up virtual devices
N_VIRTUAL_DEVICES = 4
physical_devices = tf.config.list_physical_devices("CPU")
tf.config.set_logical_device_configuration(
    physical_devices[0], [tf.config.LogicalDeviceConfiguration() for _ in range(N_VIRTUAL_DEVICES)]
)
# Temporarily uses `ReductionToOneDevice` to avoid runtime errors with `NcclAllReduce`
strategy = tf.distribute.MirroredStrategy(cross_device_ops=tf.distribute.ReductionToOneDevice())
print("Number of devices: {}".format(strategy.num_replicas_in_sync))
# Set up dataset
dataset_size = 36
batch_size = 3 * strategy.num_replicas_in_sync
dataset = tf.data.Dataset.range(dataset_size)
dataset = dataset.map(lambda x: (x, 2 * x))
dataset = dataset.batch(batch_size)
# Enable sharding via data
options = tf.data.Options()
options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.DATA
dataset = dataset.with_options(options)
print("Looping through dataset")
for data in dataset:
    print(data)
steps_per_epoch = int(dataset_size / batch_size)
dataset = dataset.repeat()
# Debug dataset through `model.fit` with `tf.print`
def inspect_batch(features, labels):
    tf.print("Batch inspection:", features, len(features))
    return features, labels
dataset = dataset.map(inspect_batch)
# Create distributed dataset
dist_dataset = strategy.experimental_distribute_dataset(dataset)
with strategy.scope():
    model = tf.keras.Sequential(
        [tf.keras.layers.Dense(128, activation="relu", input_shape=(1,)), tf.keras.layers.Dense(1)]
    )
    model.compile(optimizer="adam", loss="mean_squared_error")
class InspectBatchCallback(tf.keras.callbacks.Callback):
    def on_train_batch_begin(self, batch, logs=None):
        print(f"Batch {batch} starting.")

    def on_train_batch_end(self, batch, logs=None):
        print(f"Batch {batch} ended.")
print("Running model.fit and inspecting batches:")
model.fit(dist_dataset, epochs=3, steps_per_epoch=steps_per_epoch, callbacks=[InspectBatchCallback()])
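For comparison, this is the variant without experimental_distribute_dataset that gives me different results (a minimal sketch; model, dataset, steps_per_epoch, and the callback are exactly as defined above, and Keras is supposed to distribute a plain tf.data.Dataset automatically under the active strategy). The logs below are from the dist_dataset version:
# Pass the un-distributed dataset straight to model.fit and let Keras
# handle the per-replica split under the MirroredStrategy scope
model.fit(dataset, epochs=3, steps_per_epoch=steps_per_epoch, callbacks=[InspectBatchCallback()])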
Logs:
Number of devices: 4
Looping through dataset
(<tf.Tensor: shape=(12,), dtype=int64, numpy=array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11])>, <tf.Tensor: shape=(12,), dtype=int64, numpy=array([ 0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22])>)
(<tf.Tensor: shape=(12,), dtype=int64, numpy=array([12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23])>, <tf.Tensor: shape=(12,), dtype=int64, numpy=array([24, 26, 28, 30, 32, 34, 36, 38, 40, 42, 44, 46])>)
(<tf.Tensor: shape=(12,), dtype=int64, numpy=array([24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35])>, <tf.Tensor: shape=(12,), dtype=int64, numpy=array([48, 50, 52, 54, 56, 58, 60, 62, 64, 66, 68, 70])>)
Running model.fit and inspecting batches:
Epoch 1/3
Batch 0 starting.
Batch inspection: [0 1 2 ... 9 10 11] 12
Batch inspection: [12 13 14 ... 21 22 23] 12
Batch inspection: [24 25 26 ... 33 34 35] 12
Batch inspection: [0 1 2 ... 9 10 11] 12
Batch 0 ended.
1/3 [=========>....................] - ETA: 4s - loss: 212.5091Batch 1 starting.
Batch inspection: [12 13 14 ... 21 22 23] 12
Batch 1 ended.
Batch 2 starting.
Batch inspection: [24 25 26 ... 33 34 35] 12
Batch 2 ended.
3/3 [==============================] - 2s 10ms/step - loss: 2044.8993
Epoch 2/3
Batch 0 starting.
Batch inspection: [0 1 2 ... 9 10 11] 12
Batch 0 ended.
1/3 [=========>....................] - ETA: 0s - loss: 204.7775Batch 1 starting.
Batch inspection: [12 13 14 ... 21 22 23] 12
Batch 1 ended.
Batch 2 starting.
Batch inspection: [24 25 26 ... 33 34 35] 12
Batch 2 ended.
3/3 [==============================] - 0s 5ms/step - loss: 1983.4829
Epoch 3/3
Batch 0 starting.
Batch inspection: [0 1 2 ... 9 10 11] 12
Batch 0 ended.
1/3 [=========>....................] - ETA: 0s - loss: 198.4120Batch 1 starting.
Batch inspection: [12 13 14 ... 21 22 23] 12
Batch 1 ended.
Batch 2 starting.
Batch inspection: [24 25 26 ... 33 34 35] 12
Batch 2 ended.
3/3 [==============================] - 0s 5ms/step - loss: 1924.0229