I’ve been using an older version of Flax, where I was able to wrap an existing RNN cell in a custom RNNCell while keeping access to its carry in case I need it:
import flax.linen as nn
class CustomRNNCell(nn.Module):
    RNNCell: nn.recurrent.RNNCellBase
    hidden_dims: int
    output_dims: int
    @nn.compact
    def __call__(self, carry, inputs, train: bool):
        x = nn.Dense(features=self.hidden_dims)(inputs)
        x = nn.leaky_relu(x)
        carry, x = self.RNNCell()(carry, x)
        x = nn.Dense(features=self.hidden_dims//2)(x)
        x = nn.leaky_relu(x)
        x = nn.Dense(features=self.output_dims)(x)
        x = nn.sigmoid(x)
        return carry, x
    def initialize_carry(self, rng, batch_dims, init_fn=nn.zeros):
        # Old API: initialize_carry is called on the class itself, so no cell instance is needed
        return self.RNNCell.initialize_carry(rng, batch_dims, self.hidden_dims, init_fn)
This worked because initialize_carry could be called directly on the class. After upgrading to a newer version I get errors, because initialize_carry is now an instance method. I’m not sure what the easiest way is to reproduce the previous behavior, since the RNNCell is only instantiated inside __call__…
I can’t have something like what I’ve seen in examples, since then there’s a conflict between the carry coming as an input and one defined through initialize_carry:
class CustomRNNCell(nn.Module):
    RNNCell: nn.recurrent.RNNCellBase
    hidden_dims: int
    output_dims: int
    @nn.compact
    def __call__(self, carry, inputs, train: bool):
        x = nn.Dense(features=self.hidden_dims)(inputs)
        x = nn.leaky_relu(x)
        rnn_cell = self.RNNCell()
        carry = rnn_cell.initialize_carry(...)  # this would overwrite the input carry?
        carry, x = rnn_cell(carry, x)  # reuse the instance created above
        x = nn.Dense(features=self.hidden_dims//2)(x)
        x = nn.leaky_relu(x)
        x = nn.Dense(features=self.output_dims)(x)
        x = nn.sigmoid(x)
        return carry, x
I see a way around this by using setup to instantiate the RNNCell, and then having the initialize_carry function use that instance as before. But I would prefer to avoid that, since then I have to define every submodule in setup myself.
Is there a way to keep using the compact notation and still get carry initialization working as before? I’m guessing the change of initialize_carry from a class method to an instance method was meant to make things simpler, not more complicated, so I feel like I’m missing something.