I have defined a custom classification head for transformers.TFRobertaForSequenceClassification with two labels, as shown below, so that I can fine-tune the pretrained roberta-base model for my downstream task of classifying sentences into one of a finite set of independent labels. I would like to make that RoBERTa model part of a tensorflow.keras.Model. Here is what I currently have, using tensorflow.keras.layers.Identity as the final/output layer inside the tensorflow.keras.Model, since RoBERTa's classification head already takes care of the final classification layer and the loss.
import numpy as np
import tensorflow as tf
from transformers import TFRobertaForSequenceClassification

# distribution strategy, model definition and compilation
# (pretrained_model_path, num_output_nodes and labels are defined above)
tf_gpu_strategy = tf.distribute.MirroredStrategy(devices=["/gpu:0", "/gpu:1"], cross_device_ops=tf.distribute.HierarchicalCopyAllReduce())

MAX_SEQUENCE_LENGTH = 256

with tf_gpu_strategy.scope():
    # define the input-ids, attention-masks and token-type-ids
    input_word_ids = tf.keras.Input(shape=(MAX_SEQUENCE_LENGTH,), dtype=tf.int32, name='input_ids')
    input_attention_mask = tf.keras.Input(shape=(MAX_SEQUENCE_LENGTH,), dtype=tf.int32, name='attention_mask')
    input_token_type_ids = tf.keras.Input(shape=(MAX_SEQUENCE_LENGTH,), dtype=tf.int32, name='token_type_ids')

    # instantiate the pre-trained model
    roberta_model = TFRobertaForSequenceClassification.from_pretrained(pretrained_model_path, from_pt=True, num_labels=num_output_nodes)
    x = roberta_model(input_ids=input_word_ids, attention_mask=input_attention_mask, token_type_ids=input_token_type_ids, labels=labels)

    # Hugging Face models return multiple outputs; keep only the first element,
    # which comes from the classification head on top of the encoder
    x = x[0]

    # add the final layer needed for the task
    final_layer = tf.keras.layers.Identity()(x)

    # construct the model
    model = tf.keras.Model(inputs=[input_word_ids, input_attention_mask, input_token_type_ids], outputs=final_layer)

    # compile the model; from_logits=False means the loss expects inputs that
    # have already been normalized by a softmax
    model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False, reduction='sum_over_batch_size'),
                  optimizer=tf.keras.optimizers.Adam(learning_rate=1e-5),
                  metrics=['accuracy', tf.keras.metrics.BinaryAccuracy(), tf.keras.metrics.Precision()])
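As a sanity check on what the compiled model emits, I run a single example through it roughly like this (a sketch: the sentence is a placeholder, and I am assuming the RoBERTa tokenizer loads from the same pretrained_model_path):

from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained(pretrained_model_path)

# encode one sentence to the fixed sequence length used by the Input layers above
encoded = tokenizer("an example sentence", max_length=MAX_SEQUENCE_LENGTH,
                    padding='max_length', truncation=True, return_tensors='tf')

# feed plain tensors keyed by the Input names; RoBERTa's tokenizer does not
# return token_type_ids by default, so zeros are passed for that input
probe_output = model.predict({'input_ids': encoded['input_ids'],
                              'attention_mask': encoded['attention_mask'],
                              'token_type_ids': tf.zeros_like(encoded['input_ids'])})
print(probe_output.shape)  # checking what the Identity output layer returns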
I need the model to be a Keras model, because I have downstream code that processes the History object that tf.keras.Model.fit() returns.
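That downstream code does something along these lines (a minimal sketch; plot_history and the output filename are illustrative, and the metric keys follow whatever is passed to compile() above):

import matplotlib.pyplot as plt

def plot_history(history):
    # history.history maps each metric name (e.g. 'loss', 'val_accuracy')
    # to a list with one value per epoch
    for metric_name, values in history.history.items():
        plt.plot(range(1, len(values) + 1), values, label=metric_name)
    plt.xlabel('epoch')
    plt.legend()
    plt.savefig('training_history.png')  # illustrative filename

# called after training as: plot_history(history_fit)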
One issue I see with my existing approach is that, as shown in the source code, the classification head already computes its own loss when labels are passed, which may conflict with the loss I pass to compile().
And, when I fit,
with tf_gpu_strategy.scope():
    history_fit = model.fit(x=train_data_tf_dataset, y=None,
                            epochs=NUM_EPOCHS, batch_size=BATCH_SIZE, validation_data=val_data_tf_dataset,
                            verbose=1)
I get the following error:
ValueError: Cannot generate a hashable key for DistributedIteratorSpec(((('/job:localhost/replica:0/task:0/device:CPU:0', ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')),), True), ({'input_ids': PerReplicaSpec(TensorSpec(shape=(None, 256), dtype=tf.int32, name=None), TensorSpec(shape=(None, 256), dtype=tf.int32, name=None)), 'token_type_ids': PerReplicaSpec(TensorSpec(shape=(None, 256), dtype=tf.int32, name=None), TensorSpec(shape=(None, 256), dtype=tf.int32, name=None)), 'attention_mask': PerReplicaSpec(TensorSpec(shape=(None, 256), dtype=tf.int32, name=None), TensorSpec(shape=(None, 256), dtype=tf.int32, name=None))}, PerReplicaSpec(TensorSpec(shape=(None,), dtype=tf.int32, name=None), TensorSpec(shape=(None,), dtype=tf.int32, name=None)), PerReplicaSpec(TensorSpec(shape=(None,), dtype=tf.float64, name=None), TensorSpec(shape=(None,), dtype=tf.float64, name=None))), 139793163147184, 94333082522592) because the _serialize() method returned an unsupproted value of type <class 'transformers.tokenization_utils_base.BatchEncoding'>
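For reference, train_data_tf_dataset and val_data_tf_dataset are built roughly like this from the tokenizer output (a simplified sketch of my preprocessing; train_texts, train_labels, etc. are placeholders, and the extra float64 element that shows up in the error is omitted here):

from transformers import RobertaTokenizer

tokenizer = RobertaTokenizer.from_pretrained(pretrained_model_path)

def make_dataset(texts, labels):
    # tokenizer(...) returns a BatchEncoding (a dict-like container of tensors)
    encodings = tokenizer(list(texts), max_length=MAX_SEQUENCE_LENGTH,
                          padding='max_length', truncation=True,
                          return_token_type_ids=True, return_tensors='tf')
    # the BatchEncoding is passed straight into from_tensor_slices as the features element
    return tf.data.Dataset.from_tensor_slices((encodings, list(labels))).batch(BATCH_SIZE)

train_data_tf_dataset = make_dataset(train_texts, train_labels)
val_data_tf_dataset = make_dataset(val_texts, val_labels)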
Questions:
- Is the model definition correct? Is using Identity as the final layer, without an activation function, the right approach, given that RoBERTa's custom classification head already handles the loss, as mentioned in the source code?
- What is causing the above error, and how do I fix it?
- Am I doing something wrong somewhere else?