I’m refactoring a project written in JAX. This is my first time working with JAX.
There’s a function, let’s call it `foo`, that gets fed into `jax.lax.scan`. It has an argument `bar` that is currently part of the carry (i.e. the first argument, a tuple of variables that gets passed on to the next call). I noticed that `bar` doesn’t change throughout a single scan: the function unpacks it from the received carry and packs it into the returned carry without modification. I figured I’d better remove it from the carry, but I couldn’t figure out how. Is there a recommended way to do that?
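To make the setup concrete, here is a minimal sketch of the pattern described above. The body of `foo`, the values, and the shapes are all made up for illustration; only the structure (a constant `bar` carried unchanged through the scan) mirrors the real code.

```python
import jax
import jax.numpy as jnp

def foo(carry, x):
    # Hypothetical body: only the carry handling reflects the real code.
    state, bar = carry               # bar is unpacked from the carry...
    new_state = state * bar + x
    return (new_state, bar), new_state   # ...and packed back in unchanged

bar = jnp.asarray(0.5, dtype=jnp.float32)
xs = jnp.arange(4, dtype=jnp.float32)
init = (jnp.asarray(0.0, dtype=jnp.float32), bar)

(final_state, bar_out), ys = jax.lax.scan(foo, init, xs)
```

Here `bar_out` is the same value as `bar`, which is what suggests it does not need to be in the carry at all.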
At first I tried removing it from the carry, adding it as a keyword argument to the function, and changing the `scan` call to use `partial(foo, bar=bar)`. However, that turned out to be slow. The `foo` function is jitted, and my guess is that passing the argument this way causes it to be recompiled on every call instead of just once.
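For illustration, here is a sketch of the `partial` approach with the whole `scan` call wrapped in an outer `jax.jit`. The function `run` and the body of `foo` are hypothetical. The idea, which I have not verified against the real project, is that `bar` then becomes an ordinary traced input of `run`, so repeated calls with the same shapes and dtypes can reuse one compiled computation even though a fresh `partial` object is built during tracing.

```python
from functools import partial

import jax
import jax.numpy as jnp

def foo(state, x, bar):
    # Hypothetical body; bar is now a keyword argument, not part of the carry.
    new_state = state * bar + x
    return new_state, new_state

@jax.jit
def run(init, xs, bar):
    # The partial is created while `run` is being traced, so bar is closed
    # over as a traced value rather than baked in as a constant.
    return jax.lax.scan(partial(foo, bar=bar), init, xs)

xs = jnp.arange(4, dtype=jnp.float32)
init = jnp.asarray(0.0, dtype=jnp.float32)

final_state, ys = run(init, xs, jnp.asarray(0.5, dtype=jnp.float32))
# A second call with a different bar value but the same shape and dtype.
final_state_2, _ = run(init, xs, jnp.asarray(2.0, dtype=jnp.float32))
```

Whether this removes the slowdown depends on where the recompilation actually comes from, so it is worth measuring rather than assuming.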
I then tried to feed `bar` into the `xs` argument, but I got `builtins.ValueError: scan got value with no leading axis to scan over: 0, 0, 0, 0, 0, 0, 0, 0.`
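As far as I understand, every array in `xs` must have a leading axis whose length equals the number of scan steps, because `scan` hands one slice along that axis to the function per step. A sketch of what passing `bar` through `xs` would then require, with hypothetical names and values, is to broadcast it along a new leading axis first:

```python
import jax
import jax.numpy as jnp

def foo(state, inputs):
    # Hypothetical body; each step receives one slice of every xs leaf.
    x, bar = inputs
    new_state = state * bar + x
    return new_state, new_state

xs = jnp.arange(4, dtype=jnp.float32)
bar = jnp.asarray(0.5, dtype=jnp.float32)

# Repeat bar once per step: the result has shape (num_steps,) + bar.shape.
bar_per_step = jnp.broadcast_to(bar, xs.shape[:1] + bar.shape)

init = jnp.asarray(0.0, dtype=jnp.float32)
final_state, ys = jax.lax.scan(foo, init, (xs, bar_per_step))
```

This works, but it treats a constant as per-step data, which is presumably not the intended use of `xs`.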
Any idea what the right way to do this is?