import torch
import torch.nn as nn


class ResidualDrop(nn.Module):
    """Residual block with stochastic depth: during training, the residual
    branch is randomly dropped with probability `deathRate`."""

    def __init__(self, nChannels, interm_channel, nOutChannels, deathRate=0.0, stride=1):
        super(ResidualDrop, self).__init__()
        self.deathRate = deathRate
        self.conv1 = nn.Conv2d(nChannels, interm_channel, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(interm_channel)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(interm_channel, nOutChannels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(nOutChannels)

        # Projection shortcut when the spatial size or channel count changes.
        self.skip = nn.Sequential()
        if stride != 1 or nOutChannels != nChannels:
            self.skip = nn.Sequential(
                nn.Conv2d(nChannels, nOutChannels, kernel_size=1, stride=stride),
                nn.BatchNorm2d(nOutChannels)
            )

    def forward(self, x):
        if self.training:
            # Randomly decide whether to use the residual branch for this batch.
            if torch.rand(1).item() < self.deathRate:
                # Gate is closed: skip the operations of this residual block.
                return self.skip(x)
            else:
                # Gate is open: perform the normal operations of the residual block.
                out = self.conv1(x)
                out = self.bn1(out)
                out = self.relu(out)
                out = self.conv2(out)
                out = self.bn2(out)
                out = out + self.skip(x)
                out = self.relu(out)
                return out
        else:
            # During evaluation, always use the residual branch, but scale it by
            # its survival probability (1 - deathRate) before adding the shortcut.
            out = self.conv1(x)
            out = self.bn1(out)
            out = self.relu(out)
            out = self.conv2(out)
            out = self.bn2(out)
            out = (1 - self.deathRate) * out + self.skip(x)
            out = self.relu(out)
            return out
I'm currently trying to reimplement stochastic depth in PyTorch, but I'm concerned about backpropagation when the gate is closed. I don't think the residual branch's weights should receive gradient updates in that case, but I'm not sure how to implement this. Can I get some advice on this?
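
To make the concern concrete, here is a minimal sketch using the ResidualDrop block above (the channel counts, stride, and input shape are arbitrary choices for illustration, not part of the original code). Setting deathRate=1.0 forces the closed-gate path, so only the skip connection appears in the computation graph; after backward(), the residual branch's parameters have no gradient, while the projection shortcut's parameters do.

import torch

# Force the closed-gate path: with deathRate=1.0, torch.rand(1) < 1.0 always holds.
# stride=2 with a different channel count gives the block a projection shortcut
# (conv + BN), so the graph still contains trainable parameters.
block = ResidualDrop(nChannels=16, interm_channel=32, nOutChannels=32,
                     deathRate=1.0, stride=2)
block.train()

x = torch.randn(4, 16, 32, 32)
out = block(x)            # gate closed -> out = block.skip(x)
out.sum().backward()

# conv1/conv2 were never part of this forward pass, so autograd leaves their
# .grad as None and an optimizer step would not move them for this batch.
print(block.conv1.weight.grad)            # None
print(block.conv2.weight.grad)            # None
print(block.skip[0].weight.grad is None)  # False: the shortcut did receive gradients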