I have a tf.keras model with multiple layers. One of the layers performs a large number of small linear algebra operations, which leads to sparse GPU utilization (I can see this in the Nsight Systems profiling report). I want to use XLA to fuse these linear algebra operations and obtain a speedup.
Initially I tried the XLA auto-clustering (auto grouping) feature on the whole model, expecting a speedup. On the contrary, I observed an increase in runtime!
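For reference, this is roughly how I enabled auto-clustering; `build_model`, `x_train`, and `y_train` are placeholders for my actual model and data:

```python
import tensorflow as tf

# Enable XLA auto-clustering globally (equivalent to setting
# TF_XLA_FLAGS=--tf_xla_auto_jit=2 in the environment).
tf.config.optimizer.set_jit(True)

model = build_model()          # placeholder for my multi-layer tf.keras model
model.compile(optimizer="adam", loss="mse")
model.fit(x_train, y_train, epochs=1)   # placeholder data
```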
I ran a simple experiment:
For each layer in the model, I created a unit test that wraps the layer in its own small Keras model and enables XLA auto-clustering on that wrapped model only (rough sketch below).
With this setup, every unit test sped up when auto-clustering was enabled. However, when I applied auto-clustering to the original full model, it slowed down.
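Each unit test looked roughly like this (simplified; `benchmark_layer` and the shapes are only illustrative):

```python
import time
import tensorflow as tf

def benchmark_layer(layer, input_shape, use_xla):
    # Wrap the single layer in a minimal Keras model.
    inputs = tf.keras.Input(shape=input_shape)
    model = tf.keras.Model(inputs, layer(inputs))

    # Toggle auto-clustering for this run only.
    tf.config.optimizer.set_jit(use_xla)

    x = tf.random.normal((32,) + tuple(input_shape))
    step = tf.function(lambda t: model(t, training=False))
    step(x)  # warm-up: trace the graph (and cluster it if XLA is on)

    start = time.perf_counter()
    for _ in range(100):
        step(x)
    return time.perf_counter() - start
```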
Given this, I want to apply XLA clustering to just the single layer that I know benefits the most from it. One way to do this is to wrap the layer's computation in @tf.function(jit_compile=True). The problem is that this layer contains many TensorFlow operations that are not supported by XLA, so I end up having to refactor the code manually.
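This is roughly what I am trying now; the JitWrapper class below is only for illustration, not my actual code:

```python
import tensorflow as tf

class JitWrapper(tf.keras.layers.Layer):
    """Wraps an existing layer so that only its forward pass is XLA-compiled."""

    def __init__(self, inner_layer, **kwargs):
        super().__init__(**kwargs)
        self.inner_layer = inner_layer

    # jit_compile=True forces XLA compilation of this call alone,
    # but compilation fails if the inner layer uses ops XLA does not support.
    @tf.function(jit_compile=True)
    def call(self, inputs):
        return self.inner_layer(inputs)
```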
Is there a simpler way to apply XLA clustering to only one layer of a tf.keras model?