I’m working on an image segmentation task and am trying to use a pre-trained Swin Transformer Large (Swin-L) model as the feature-extraction backbone (encoder). The code runs fine on a CPU runtime in Colab, but when I switch to a TPU runtime it throws the error shown below.
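Colab's CPU and TPU runtimes are separate environments and need not ship the same TensorFlow and Keras versions. Since TensorFlow 2.16, `tf.keras` refers to Keras 3 unless `TF_USE_LEGACY_KERAS=1` is set (with the `tf-keras` package installed), so the same import line can resolve to different Keras implementations on the two runtimes. A minimal, standard-library-only sketch for comparing them is to run the snippet below in each runtime; the list of distribution names (including `tensorflow-cpu` and `tensorflow-tpu`) is an assumption about what might be installed, and missing packages are simply reported as `None`.

```python
import os
from importlib import metadata

# Distribution names to inspect (assumed; absent ones are reported as None).
# "tf-keras" is the PyPI name of the legacy Keras 2 package imported as `tf_keras`.
PACKAGES = ("tensorflow", "tensorflow-cpu", "tensorflow-tpu", "keras", "tf-keras", "tfswin")


def runtime_report(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if absent."""
    report = {}
    for name in packages:
        try:
            report[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            report[name] = None
    # With TensorFlow >= 2.16, tf.keras means Keras 3 unless this is set to "1".
    report["TF_USE_LEGACY_KERAS"] = os.environ.get("TF_USE_LEGACY_KERAS")
    return report


if __name__ == "__main__":
    for key, value in runtime_report().items():
        print(f"{key}: {value}")
```

Diffing the two printouts shows at a glance whether the runtimes differ in their TensorFlow, Keras or `tfswin` versions.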
The Code:
import tensorflow as tf
from tensorflow.keras import layers, Model, Input
from tfswin import SwinTransformerLarge224

def load_swin_encoder(input_shape=(512, 512, 3)):
    # Load the pre-trained Swin-L model without its classification head
    swin_encoder = SwinTransformerLarge224(include_top=False, weights='imagenet',
                                           input_shape=input_shape)
    # Freeze the pre-trained layers
    for layer in swin_encoder.layers:
        layer.trainable = False
    # Extract five feature maps: the 'normalize' output plus the four stage outputs
    stage_outputs = [
        swin_encoder.get_layer('normalize').output,  # Stage 0
        swin_encoder.get_layer('layers.0').output,   # Stage 1
        swin_encoder.get_layer('layers.1').output,   # Stage 2
        swin_encoder.get_layer('layers.2').output,   # Stage 3
        swin_encoder.get_layer('layers.3').output,   # Stage 4
    ]
    return Model(swin_encoder.input, stage_outputs, name="SwinTransformerEncoder")

# Test code
encoder = load_swin_encoder(input_shape=(512, 512, 3))
dummy_input = tf.random.uniform((1, 512, 512, 3))
encoder_outputs = encoder(dummy_input)
for i, output in enumerate(encoder_outputs):
    print(f"Stage {i} output shape: {output.shape}")
The Error:
The code throws the following error on TPU:
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-28-3cb122d32678> in <cell line: 2>()
      1 # loading Sanity check
----> 2 encoder = load_swin_encoder(input_shape=(512, 512, 3))
      3 dummy_input = tf.random.uniform((1, 512, 512, 3))
      4 encoder_outputs = encoder(dummy_input)
      5
/usr/local/lib/python3.10/dist-packages/keras/src/models/functional.py in __init__(self, inputs, outputs, name, **kwargs)
    117         for x in flat_inputs:
    118             if not isinstance(x, backend.KerasTensor):
--> 119                 raise ValueError(
    120                     "All `inputs` values must be KerasTensors. Received: "
    121                     f"inputs={inputs} including invalid value {x} of "

ValueError: All `inputs` values must be KerasTensors. Received: inputs=KerasTensor(type_spec=TensorSpec(shape=(None, 512, 512, 3), dtype=tf.float32, name='input_4'), name='input_4', description="created by layer 'input_4'") including invalid value KerasTensor(type_spec=TensorSpec(shape=(None, 512, 512, 3), dtype=tf.float32, name='input_4'), name='input_4', description="created by layer 'input_4'") of type <class 'tf_keras.src.engine.keras_tensor.KerasTensor'>
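Reading the traceback, the `isinstance` check that fails lives in `keras/src/models/functional.py`, which appears to be the standalone Keras 3 package that `tensorflow.keras` resolves to on this runtime, while the rejected input is a `tf_keras.src.engine.keras_tensor.KerasTensor` from the legacy Keras 2 package (`tf_keras`). That suggests `tfswin` built the backbone with `tf_keras` while the wrapping `Model` comes from Keras 3, and the two `KerasTensor` classes are not interchangeable. The hypothetical helpers below are a minimal, untested sketch of one way to keep them consistent: `keras_package_of` reports which top-level package defined an object's class, and `matching_model_class` returns the `Model` class from that same package. It assumes the backbone tensors come from either `keras` or `tf_keras` and that the corresponding package is importable.

```python
import importlib


def keras_package_of(obj):
    """Return the top-level package that defines type(obj), e.g. 'keras' or 'tf_keras'."""
    return type(obj).__module__.split(".", 1)[0]


def matching_model_class(tensor):
    """Hypothetical helper: return `Model` from the package that created `tensor`.

    Assumes `tensor` is a KerasTensor from either standalone Keras 3 (`keras`)
    or legacy Keras 2 (`tf_keras`); both export a top-level `Model` class.
    """
    package = keras_package_of(tensor)
    if package not in ("keras", "tf_keras"):
        raise TypeError(f"{type(tensor)!r} does not look like a Keras tensor")
    return importlib.import_module(package).Model


# Possible use inside load_swin_encoder (untested on TPU):
#     model_cls = matching_model_class(swin_encoder.input)
#     return model_cls(swin_encoder.input, stage_outputs, name="SwinTransformerEncoder")
```

If `keras_package_of(swin_encoder.input)` reports `tf_keras`, writing `tf_keras.Model` directly would be the simpler equivalent; the lookup only avoids hard-coding which Keras the installed `tfswin` version uses.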
Question:
Why does this code work on a CPU but fail on a TPU in Colab? How can I fix this issue to make it compatible with TPU execution?
Any insights or guidance would be greatly appreciated. Thank you!