I am learning PyTorch. In the code examples, the model is switched between training and testing by calling `model.train()` and `model.eval()`. I understand that this has to be done to deactivate training-specific behaviour, such as dropout, normalisation and gradient computation, while testing.
Many such examples also use `torch.no_grad()`, which I understand is a way of explicitly asking PyTorch to stop calculating gradients.
My question is: if gradient calculation is already stopped in `model.eval()` mode, why do we also have to use `torch.no_grad()`?
Below is an example where both `model.eval()` and `torch.no_grad()` are used:
class QuertyModel(nn.Module):
    """Fully connected network mapping 2 input features to 3 output scores.

    Layer widths are 2 -> 8 -> 16 -> 32 -> 16 -> 8 -> 3. A ReLU follows every
    layer except the last, so ``forward`` returns raw (unnormalised) logits.
    """

    def __init__(self):
        super().__init__()
        # (attribute name, in_features, out_features). Layers are registered
        # in this order, which fixes the order of parameters() / state_dict
        # entries and the order in which their random initialisation happens.
        specs = (
            ("input", 2, 8),
            ("hl1", 8, 16),
            ("hl2", 16, 32),
            ("hl3", 32, 16),
            ("hl4", 16, 8),
            ("output", 8, 3),
        )
        for name, n_in, n_out in specs:
            setattr(self, name, nn.Linear(n_in, n_out))

    def forward(self, x):
        """Return the 3 class logits for ``x``, whose last dimension must be 2."""
        for layer in (self.input, self.hl1, self.hl2, self.hl3, self.hl4):
            x = F.relu(layer(x))
        # No activation on the final layer: the logits go straight to the loss.
        return self.output(x)
def _accuracy(logits, targets):
    """Return the percentage of rows whose arg-max class equals ``targets``.

    Classes are taken along dim 1 of ``logits``; ``targets`` is presumably a
    1-D tensor of integer class labels (confirm against the loaders in use).
    The result is a plain Python float in [0, 100].
    """
    predictions = torch.argmax(logits, dim=1)
    return 100.0 * (predictions == targets).float().mean().item()


# Function to train the model
def train_model(model, lr, num_epochs, *, train_data=None, test_data=None):
    """Train ``model`` with SGD and cross-entropy, evaluating after each epoch.

    Args:
        model: The ``nn.Module`` to train; its parameters are updated in place.
        lr: Learning rate passed to ``torch.optim.SGD``.
        num_epochs: Number of passes over the training batches.
        train_data: Iterable of ``(X, y)`` batches. Defaults to the
            module-level ``train_loader``.
        test_data: Iterable of ``(X, y)`` batches. Only its FIRST batch is
            used for evaluation each epoch. Defaults to the module-level
            ``test_loader``.

    Returns:
        ``(train_acc, train_loss, test_acc, test_loss)``: four lists with one
        entry per epoch. Accuracies are percentages. Train values are means
        over that epoch's batches; test values come from one batch.
    """
    # Fall back to the module-level loaders so existing callers keep working.
    train_batches = train_data if train_data is not None else train_loader
    test_batches = test_data if test_data is not None else test_loader

    loss_function = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    train_acc, train_loss, test_acc, test_loss = [], [], [], []

    for epoch in range(num_epochs):
        print(f"{epoch+1}/{num_epochs}")

        # train() switches layers such as dropout / batch-norm into training
        # behaviour. It does not control autograd.
        model.train()
        batch_acc, batch_loss = [], []
        for X, y in train_batches:
            # forward pass
            y_hat = model(X)
            loss = loss_function(y_hat, y)
            # backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # Store Python floats (not tensors) so np.mean works regardless
            # of the tensors' device.
            batch_acc.append(_accuracy(y_hat, y))
            batch_loss.append(loss.item())
        train_acc.append(np.mean(batch_acc))
        train_loss.append(np.mean(batch_loss))

        # eval() only changes layer behaviour (e.g. disables dropout, makes
        # batch-norm use running statistics); autograd would still record the
        # forward pass. torch.no_grad() is what stops graph construction, which
        # saves memory and compute during evaluation, so both are needed.
        model.eval()
        X, y = next(iter(test_batches))
        with torch.no_grad():
            y_hat = model(X)
            test_acc.append(_accuracy(y_hat, y))
            test_loss.append(loss_function(y_hat, y).item())

    return train_acc, train_loss, test_acc, test_loss