I am new to PyTorch and started coding about a month ago.
This is my GRU code:
import torch
import torch.nn as nn

hidden_size = 32
gru_layers_count = 2
encoder = nn.GRU(hidden_size,
                 hidden_size,
                 num_layers=gru_layers_count,
                 batch_first=True, bidirectional=True)
ip = torch.randn(64, 100, hidden_size)  # (batch, seq_len, input_size)
op, hn = encoder(ip)
print(op.shape, hn.shape)
The output is:
torch.Size([64, 100, 64]) torch.Size([4, 64, 32])
Here I am concerned with the shape of hn. Its first dimension has size 4, so I assume it is 2 GRU layers * 2 directions.
However, I am a little confused about the arrangement.
So my questions are:
Are the first 2 entries the forward hidden states and the last 2 the backward ones, or do the forward and backward hidden states alternate?
Is the following the correct way to extract only the forward GRU states?
forward_hidden = hn[[x for x in range(0, gru_layers_count * 2, 2)], :, :]
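One way to sanity-check the arrangement is sketched below. It relies on op containing only the last layer's outputs, with the forward half in the first hidden_size features and the backward half in the rest. The names batch_size, seq_len, and hn_by_layer are my own and are only for illustration. If the states alternate by direction within each layer, both comparisons should print True, and the even-index slice should give one forward state per layer.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

hidden_size = 32
gru_layers_count = 2
batch_size, seq_len = 64, 100  # illustrative names, same sizes as in the question

encoder = nn.GRU(hidden_size, hidden_size,
                 num_layers=gru_layers_count,
                 batch_first=True, bidirectional=True)

ip = torch.randn(batch_size, seq_len, hidden_size)
op, hn = encoder(ip)

# Assumed layout: first dim of hn unrolls as (layer, direction).
hn_by_layer = hn.reshape(gru_layers_count, 2, batch_size, hidden_size)

# The forward direction finishes at the last time step,
# the backward direction finishes at the first time step.
last_fwd_from_op = op[:, -1, :hidden_size]
last_bwd_from_op = op[:, 0, hidden_size:]

fwd_matches = torch.allclose(hn_by_layer[-1, 0], last_fwd_from_op)
bwd_matches = torch.allclose(hn_by_layer[-1, 1], last_bwd_from_op)
print(fwd_matches, bwd_matches)

# Same selection as the list-comprehension indexing, written as a slice.
forward_hidden = hn[0::2]
print(forward_hidden.shape)
```

This check only confirms the layout for the last layer directly, since op does not expose the intermediate layers.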