Right now, I’m training an RNN for my DGPS application that takes latitude, longitude and altitude as input. Here’s the network architecture:
# Define the network classes by subclassing nn.Module
class ResBlockMLP(nn.Module):
    """Residual MLP block with a learned projection shortcut.

    The input is first normalised and activated (LayerNorm -> ELU). That
    result feeds two paths whose outputs are summed:

    * main path:  Linear(in, in // 2) -> LayerNorm -> ELU -> Linear(in // 2, out)
    * shortcut:   Linear(in, out)

    Because the shortcut is a projection, ``input_size`` and ``output_size``
    may differ.

    Args:
        input_size: size of the last dimension of the input.
        output_size: size of the last dimension of the output.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        bottleneck = input_size // 2
        # Keep this creation order: it determines the RNG draws used for
        # weight initialisation and the parameter order seen by optimizers.
        self.norm1 = nn.LayerNorm(input_size)
        self.fc1 = nn.Linear(input_size, bottleneck)
        self.norm2 = nn.LayerNorm(bottleneck)
        self.fc2 = nn.Linear(bottleneck, output_size)
        self.fc3 = nn.Linear(input_size, output_size)  # projection shortcut
        self.act = nn.ELU()

    def forward(self, x):
        """Return ``main(x) + shortcut(x)``; the last dim becomes ``output_size``."""
        pre = self.act(self.norm1(x))
        hidden = self.act(self.norm2(self.fc1(pre)))
        return self.fc2(hidden) + self.fc3(pre)
class RNN(nn.Module):
    """Window MLP with a learned recurrent state ("buffer").

    Each call does the following:

    1. Flattens an input window of ``seq_len`` steps x ``num_features`` values.
    2. Embeds it with a small MLP.
    3. Concatenates the caller-supplied state ``buffer_in``.
    4. Passes the result through a linear merge layer and ``num_blocks``
       residual MLP blocks.

    It returns a prediction and a new state squashed by ``tanh`` into
    ``[-1, 1]``. The caller feeds that state back as ``buffer_in`` on the
    next call.

    Args:
        seq_len: number of time steps per input window. Must equal
            ``input_seq.shape[1]`` at call time, e.g. 14 for an input of shape
            ``(32, 14, 2)``. Building the model with a different value makes
            the first Linear layer's input width disagree with the data.
        output_size: number of values predicted per sample.
        num_blocks: number of ``ResBlockMLP`` blocks after the merge layer.
        buffer_size: width of the recurrent state vector.
        num_features: values per time step (default 2).
        hidden_size: width of the internal representation (default 128).
    """

    def __init__(self, seq_len, output_size, num_blocks=1, buffer_size=128,
                 num_features=2, hidden_size=128):
        super().__init__()
        self.seq_len = seq_len
        self.num_features = num_features
        self.buffer_size = buffer_size
        seq_data_len = seq_len * num_features
        # Module creation order and attribute names are significant: they fix
        # the initialisation RNG draws and the state_dict keys.
        self.input_mlp = nn.Sequential(
            nn.Linear(seq_data_len, 4 * seq_data_len),
            nn.ELU(),
            nn.Linear(4 * seq_data_len, hidden_size),
            nn.ELU(),
        )
        # The merge layer consumes [buffer_in, input_vec], so its input width
        # must track buffer_size; a fixed 256 only fits buffer_size == 128.
        self.rnn = nn.Linear(buffer_size + hidden_size, hidden_size)
        blocks = [ResBlockMLP(hidden_size, hidden_size) for _ in range(num_blocks)]
        self.res_blocks = nn.Sequential(*blocks)
        self.fc_out = nn.Linear(hidden_size, output_size)
        self.fc_buffer = nn.Linear(hidden_size, buffer_size)
        self.act = nn.ELU()

    def init_buffer(self, batch_size, device=None, dtype=None):
        """Return an all-zero state of shape ``(batch_size, buffer_size)``.

        ``device`` and ``dtype`` default to those of the model's parameters.
        """
        ref = self.fc_buffer.weight
        return torch.zeros(
            batch_size,
            self.buffer_size,
            device=ref.device if device is None else device,
            dtype=ref.dtype if dtype is None else dtype,
        )

    def forward(self, input_seq, buffer_in=None):
        """Run one step of the model.

        Args:
            input_seq: tensor whose first dim is the batch and whose remaining
                dims hold ``seq_len * num_features`` values per sample,
                e.g. shape ``(batch, seq_len, num_features)``.
            buffer_in: state of shape ``(batch, buffer_size)``, normally the
                second output of the previous call. ``None`` starts from zeros.

        Returns:
            ``(prediction, buffer_out)`` of shapes ``(batch, output_size)`` and
            ``(batch, buffer_size)``.

        Raises:
            RuntimeError: if ``input_seq`` or ``buffer_in`` does not match the
                sizes the model was constructed with.
        """
        batch_size = input_seq.shape[0]
        flat = input_seq.reshape(batch_size, -1)
        expected = self.seq_len * self.num_features
        if flat.shape[1] != expected:
            raise RuntimeError(
                f"RNN expects seq_len * num_features = {self.seq_len} * "
                f"{self.num_features} = {expected} values per sample, but "
                f"input_seq of shape {tuple(input_seq.shape)} provides "
                f"{flat.shape[1]}; construct the model with "
                f"seq_len=input_seq.shape[1] and num_features=input_seq.shape[2]"
            )

        if buffer_in is None:
            buffer_in = self.init_buffer(batch_size)
        elif tuple(buffer_in.shape) != (batch_size, self.buffer_size):
            raise RuntimeError(
                f"buffer_in has shape {tuple(buffer_in.shape)} but "
                f"({batch_size}, {self.buffer_size}) was expected; re-create it "
                f"with init_buffer(batch_size) whenever the batch size changes "
                f"(e.g. on a smaller final batch)"
            )

        input_vec = self.input_mlp(flat)
        # State first, then the window embedding: self.rnn's weight columns
        # are laid out in this order.
        x_cat = torch.cat((buffer_in, input_vec), dim=1)
        x = self.rnn(x_cat)
        x = self.act(self.res_blocks(x))
        return self.fc_out(x), torch.tanh(self.fc_buffer(x))
However, this line throws an error:
data_pred, buffer = gps_rnn(seq_block, buffer)
Error:
File D:\ProgramData\Miniconda_3.9\envs\rnn-sample-py3.9\lib\site-packages\torch\nn\modules\linear.py:114, in Linear.forward(self, input)
113 def forward(self, input: Tensor) -> Tensor:
--> 114 return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (32x126 and 1600x6400)
Someone suggested using the torchsummary module to see how a tensor flows through the network. Running print(seq_block.size()) shows that the input shape is (32, 14, 2).
However, I have problems running torchsummary:
from torchsummary import summary
summary(gps_rnn, (32,14,2))
Error: https://pastebin.com/Lt9rZD3y
I also tried the torchinfo package, since it is more up to date, and got the following error instead:
from torchinfo import summary
summary(gps_rnn, input_size=(batch_size, 32, 14, 2))
Error: https://pastebin.com/rmuSH0j7
I also tried this solution for passing two arguments to the summary function, but it also throws an error: https://pastebin.com/tma4cWyN
How do I solve this problem and train the network? Your help is very much appreciated.
Kenneth Ligutom is a new contributor to this site. Take care in asking for clarification, commenting, and answering.
Check out our Code of Conduct.