I am trying to run a neural network model in PyTorch, but I am encountering a RuntimeError related to tensor reshaping. The error occurs when I attempt to reshape a tensor within the Conv class of my model. Here’s the specific error message:
RuntimeError: shape ‘[-1, 205, 202]’ is invalid for input of size 606
class Conv(nn.Module):
def __init__(self, conv1d_1, conv1d_2, maxpool1d_1, maxpool1d_2, fc_1_size, fc_2_size):
super(Conv, self).__init__()
self.conv1d_1_args = conv1d_1
self.conv1d_1 = nn.Conv1d(**conv1d_1)
self.conv1d_2 = nn.Conv1d(**conv1d_2)
fc1_size = get_conv_mp_out_size(fc_1_size, conv1d_2, [maxpool1d_1, maxpool1d_2])
fc2_size = get_conv_mp_out_size(fc_2_size, conv1d_2, [maxpool1d_1, maxpool1d_2])
self.fc1 = nn.Linear(fc1_size, 1)
self.fc2 = nn.Linear(fc2_size, 1)
self.drop = nn.Dropout(p=0.2)
self.mp_1 = nn.MaxPool1d(**maxpool1d_1)
self.mp_2 = nn.MaxPool1d(**maxpool1d_2)
def forward(self, hidden, x):
concat = torch.cat([hidden, x], 1)
print(f"Concat shape: {concat.shape}")
batch_size = concat.shape[0]
concat_size = hidden.shape[1] + x.shape[1]
print(f"Expected shape after reshaping: (-1, {self.conv1d_1_args['in_channels']}, {concat_size})")
total_elements_before = concat.numel()
total_elements_after = batch_size * self.conv1d_1_args["in_channels"] * concat_size
if total_elements_before != total_elements_after:
raise ValueError(f"Shape mismatch: cannot reshape tensor of total elements {total_elements_before} "
f"to (-1, {self.conv1d_1_args['in_channels']}, {concat_size}), which has {total_elements_after} elements.")
concat = concat.view(-1, self.conv1d_1_args["in_channels"], concat_size)
print(f"Concat reshaped to: {concat.shape}")
Z = self.mp_1(F.relu(self.conv1d_1(concat)))
print(f"Shape after first conv and maxpool: {Z.shape}")
Z = self.mp_2(self.conv1d_2(Z))
print(f"Shape after second conv and maxpool: {Z.shape}")
hidden = hidden.view(-1, self.conv1d_1_args["in_channels"], hidden.shape[1])
print(f"Hidden reshaped to: {hidden.shape}")
Y = self.mp_1(F.relu(self.conv1d_1(hidden)))
print(f"Shape after first conv and maxpool on hidden: {Y.shape}")
Y = self.mp_2(self.conv1d_2(Y))
print(f"Shape after second conv and maxpool on hidden: {Y.shape}")
Z_flatten_size = int(Z.shape[1] * Z.shape[-1])
Y_flatten_size = int(Y.shape[1] * Y.shape[-1])
print(f"Z flatten size: {Z_flatten_size}, Y flatten size: {Y_flatten_size}")
Z = Z.view(-1, Z_flatten_size)
Y = Y.view(-1, Y_flatten_size)
res = self.fc1(Z) * self.fc2(Y)
res = self.drop(res)
sig = torch.sigmoid(torch.flatten(res))
return sig
Here is how I am initializing and using the model:
import torch
from torch_geometric.data import Data
# Import the Net class from the model module
from src.process.model import Net
# Example parameters for initializing the Net class
gated_graph_conv_args = {
'out_channels': 200,
'num_layers': 6
}
conv_args = {
'conv1d_1': {'in_channels': 205, 'out_channels': 50, 'kernel_size': 3, 'stride': 1},
'conv1d_2': {'in_channels': 50, 'out_channels': 20, 'kernel_size': 1, 'stride': 1},
'maxpool1d_1': {'kernel_size': 2, 'stride': 2},
'maxpool1d_2': {'kernel_size': 2, 'stride': 2}
}
emb_size = 100
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Initialize the model with adjusted parameters
the_model = Net(gated_graph_conv_args, conv_args, emb_size, device)
# Load the state dictionary
model_path = PATHS.model + "checkpoint.pt"
state_dict = torch.load(model_path)
# Load the state dictionary into the model
the_model.load_state_dict(state_dict)
# Set the model to evaluation mode
the_model.eval()
# Example node features (x) and edge indices (edge_index)
x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=torch.float)
edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]], dtype=torch.long)
# Create a Data object
data = Data(x=x, edge_index=edge_index)
# Pass the Data object to the model
output = the_model(data)
# Print the model output
print(output)
The error occurs specifically at the line:
concat = concat.view(-1, self.conv1d_1_args["in_channels"], concat_size)
How can I resolve the shape mismatch issue during reshaping in my PyTorch model? Any suggestions on best practices for handling such tensor reshaping issues would be greatly appreciated, as I am new to this.
Thank you for your help!