I’m trying to create a DataLoader class for a TensorFlow transformer project. I’ve included what I think are the relevant parts of the class below.
class DataLoader:
    def __init__(self, ...):
        ...some data loading steps...
        self.create_tf_dataset()

    @tf.function
    def create_tf_dataset(self):
        self.dataset = tf.data.Dataset.range(self.dataset_size)
        self.dataset = self.dataset.shuffle(buffer_size=10000)
        self.dataset = self.dataset.batch(self.batch_size)
        self.dataset = self.dataset.map(self.index_tensor_to_data)
        self.dataset = self.dataset.prefetch(1)

    ...

    @tf.function
    def index_tensor_to_data(self, index_tensor):
        inputs = []
        outputs = []
        for index in index_tensor:
            input_, output_ = self.index_to_data(index)
            inputs.append(input_)
            outputs.append(output_)
        inputs = tf.convert_to_tensor(inputs)
        outputs = tf.convert_to_tensor(outputs)
        start_tag = -tf.ones((outputs.shape[0], 1, outputs.shape[2]))
        contexts = tf.concat([start_tag, outputs[:, :-1, :]], axis=1)
        return (inputs, contexts), outputs
I’m receiving the following error while trying to construct the DataLoader:
The tensor <tf.Tensor 'while/strided_slice_2:0' shape=(30, 2) dtype=float32> cannot be accessed from FuncGraph(name=index_tensor_to_data, id=10747356432), because it was defined in FuncGraph(name=while_body_24470, id=10747258512), which is out of scope.
I tried fiddling around with the @tf.function decorator, but clearly both functions are still being traced into a TensorFlow graph (the FuncGraph named in the error).
This is my first time working with TensorFlow, and I couldn’t find online documentation for creating Dataset classes of my own, like the documentation that exists for PyTorch. Is my approach of mapping a tf.data.Dataset.range dataset onto my data tensors standard?
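For comparison, here is a minimal sketch of the same index-mapping idea with the per-index Python loop replaced by tf.gather, which avoids appending tensors to Python lists inside a traced loop (that seems to be what the error is complaining about, though I have not confirmed it). It assumes the whole dataset fits in memory as two tensors. The constructor arguments and the (100, 30, 2) shapes are placeholders for illustration and not my actual loading code, and there are no @tf.function decorators because Dataset.map traces the mapped function on its own.

```python
import numpy as np
import tensorflow as tf


class DataLoader:
    def __init__(self, inputs, outputs, batch_size):
        # Hypothetical stand-in for the real loading steps: keep the
        # full dataset in memory as two tensors.
        self.inputs = tf.constant(inputs)
        self.outputs = tf.constant(outputs)
        self.dataset_size = int(self.inputs.shape[0])
        self.batch_size = batch_size
        self.dataset = self.create_tf_dataset()

    def create_tf_dataset(self):
        # Plain Python method: it only assembles the pipeline.
        ds = tf.data.Dataset.range(self.dataset_size)
        ds = ds.shuffle(buffer_size=10000)
        ds = ds.batch(self.batch_size)
        ds = ds.map(self.index_tensor_to_data)
        return ds.prefetch(1)

    def index_tensor_to_data(self, index_tensor):
        # Look up the whole batch of indices at once instead of looping.
        inputs = tf.gather(self.inputs, index_tensor)
        outputs = tf.gather(self.outputs, index_tensor)
        # Start tag of -1 for the first step, then outputs shifted right by one.
        start_tag = -tf.ones_like(outputs[:, :1, :])
        contexts = tf.concat([start_tag, outputs[:, :-1, :]], axis=1)
        return (inputs, contexts), outputs


# Placeholder data: 100 examples, 30 time steps, 2 features.
example_inputs = np.random.rand(100, 30, 2).astype(np.float32)
example_outputs = np.random.rand(100, 30, 2).astype(np.float32)
loader = DataLoader(example_inputs, example_outputs, batch_size=16)
(batch_inputs, batch_contexts), batch_outputs = next(iter(loader.dataset))
```

If the data does not fit in memory, this sketch would not apply as written; something like tf.data.Dataset.from_generator or tf.py_function inside map would presumably be needed instead.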