Imagine a network like this:
import torch
import torch.nn as nn


class CoolCNN(nn.Module):
    def __init__(self):
        super(CoolCNN, self).__init__()
        # Two convolutions that both read the raw input (the two branches)
        self.initial_conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.parallel_conv = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, padding=1)
        self.secondary_conv = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1)
        # A single pooling module, reused three times in forward()
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # 32 * 8 * 8 implies a 32x32 input: two 2x2 pools reduce 32 -> 16 -> 8
        self.fully_connected1 = nn.Linear(32 * 8 * 8, 128)
        self.output_layer = nn.Linear(128, 10)

    def forward(self, x):
        # Both branches consume the same input x
        main_path = self.max_pool(torch.relu(self.initial_conv(x)))
        parallel_path = self.max_pool(torch.relu(self.parallel_conv(x)))
        # Merge the branches by averaging
        x = (main_path + parallel_path) / 2
        x = self.max_pool(torch.relu(self.secondary_conv(x)))
        # Flatten before the fully connected layers
        x = x.view(-1, 32 * 8 * 8)
        x = torch.relu(self.fully_connected1(x))
        x = self.output_layer(x)
        return x
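As a quick sanity check of the shapes, here is a minimal sketch that runs one forward pass. The 32x32 RGB input and the batch size of 4 are assumptions on my part; the input size is inferred from the `32 * 8 * 8` in the first linear layer and the two pooling steps each branch of data goes through.

```python
import torch
import torch.nn as nn


class CoolCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.initial_conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.parallel_conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.secondary_conv = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fully_connected1 = nn.Linear(32 * 8 * 8, 128)
        self.output_layer = nn.Linear(128, 10)

    def forward(self, x):
        main_path = self.max_pool(torch.relu(self.initial_conv(x)))
        parallel_path = self.max_pool(torch.relu(self.parallel_conv(x)))
        x = (main_path + parallel_path) / 2
        x = self.max_pool(torch.relu(self.secondary_conv(x)))
        x = x.view(-1, 32 * 8 * 8)
        x = torch.relu(self.fully_connected1(x))
        return self.output_layer(x)


model = CoolCNN()
dummy = torch.randn(4, 3, 32, 32)  # assumed input: batch of 4 RGB 32x32 images
with torch.no_grad():              # forward pass only, no autograd graph
    out = model(dummy)
print(out.shape)  # torch.Size([4, 10])
```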
The network has this tree-like structure:
CoolCNN
└── Forward Pass
    ├── parallel_path
    │   ├── parallel_conv (Conv2d)
    │   ├── ReLU Activation
    │   └── max_pool (MaxPool2d)
    │
    ├── main_path
    │   ├── initial_conv (Conv2d)
    │   ├── ReLU Activation
    │   └── max_pool (MaxPool2d)
    │
    ├── Average main_path and parallel_path
    │
    ├── secondary_conv (Conv2d)
    ├── ReLU Activation
    ├── max_pool (MaxPool2d)
    │
    ├── Flatten the Tensor
    │
    ├── fully_connected1 (Linear)
    ├── ReLU Activation
    │
    └── output_layer (Linear)
As you can see, the fact that the main and parallel paths run in parallel is accurately represented in the tree drawn above.
I wanted to know if it’s possible to somehow extract this tree-like structure from just a forward pass through the network without relying on the backward pass at all.
The only approach that I could find was forward hooks. That is in fact how torchsummary prints the model’s summary.
However, with a forward hook, all you get to know is the order in which the modules get called. That order is just one topological sort of the underlying graph, and in general it’s impossible to accurately reconstruct a graph from a topological sort alone, because different graphs can share the same ordering. This means it’s impossible to accurately derive the actual unique tree representation from forward hooks alone (at least that is my understanding).
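To make the limitation concrete, here is a minimal sketch of the hook approach applied to the network above. `record_call_order` is a hypothetical helper name I made up for illustration, and the 1x3x32x32 input is an assumption inferred from the `32 * 8 * 8` in the model. It only uses `named_modules()` and `register_forward_hook()`.

```python
import torch
import torch.nn as nn


class CoolCNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.initial_conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.parallel_conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.secondary_conv = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.max_pool = nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        self.fully_connected1 = nn.Linear(32 * 8 * 8, 128)
        self.output_layer = nn.Linear(128, 10)

    def forward(self, x):
        main_path = self.max_pool(torch.relu(self.initial_conv(x)))
        parallel_path = self.max_pool(torch.relu(self.parallel_conv(x)))
        x = (main_path + parallel_path) / 2
        x = self.max_pool(torch.relu(self.secondary_conv(x)))
        x = x.view(-1, 32 * 8 * 8)
        x = torch.relu(self.fully_connected1(x))
        return self.output_layer(x)


def record_call_order(model, example_input):
    """Hypothetical helper: return submodule names in the order they finish."""
    calls = []
    handles = []
    for name, module in model.named_modules():
        if name == "":  # skip the root module itself
            continue

        # Bind the current name via a default argument so each hook keeps its own
        def hook(mod, inputs, output, name=name):
            calls.append(name)

        handles.append(module.register_forward_hook(hook))
    try:
        with torch.no_grad():  # forward pass only
            model(example_input)
    finally:
        for handle in handles:  # always detach the hooks afterwards
            handle.remove()
    return calls


calls = record_call_order(CoolCNN(), torch.randn(1, 3, 32, 32))
print(calls)
# ['initial_conv', 'max_pool', 'parallel_conv', 'max_pool',
#  'secondary_conv', 'max_pool', 'fully_connected1', 'output_layer']
```

The result is a flat list. Nothing in it says that `initial_conv` and `parallel_conv` both read the same input rather than feeding one another, and the functional calls (`torch.relu`, the averaging, `view`) do not show up at all, since they are not modules.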
I have already seen this question where the OP seemed to be happy without an exact graph reconstruction, so it doesn’t address my need.
So, is there a way to reconstruct the computation graph from a forward pass alone?