https://smp.readthedocs.io/en/latest/insights.html
I am attempting to create my own encoder for the smp
DeepLabV3+ model. I have some locally trained encoders that I wanted to use but first off, I want to try it with an already existing encoder, namely torchvision.models.segmentation.deeplabv3_resnet50
to ensure that I can test against something that I know works (my current model uses the in-built resnet50 backbone).
Attempting to follow the docs above I have generated the following code
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp
import torchvision.models.segmentation as segmentation
from torchvision.models.resnet import Bottleneck
class resnet50_encoder(nn.Module, smp.encoders._base.EncoderMixin):
def __init__(self, **kwargs):
super().__init__()
# Define your encoder module
self.encoder = segmentation.deeplabv3_resnet50(pretrained=True)
# Set the number of output channels for each feature tensor
self._out_channels = [3, 64, 64, 128, 256, 512]
# Set the depth (number of downsampling operations)
self._depth = 5
# Set the default number of input channels (usually 3 for RGB images)
self._in_channels = 3
def forward(self, x: torch.Tensor):
# Get features from the encoder
features = self.encoder(x)['out']
# Return features sorted in descending order of spatial resolution
return [features[f] for f in ['0', '1', '2', '3', '4', '5']]
smp.encoders.encoders["deeplab_resnet50"] = {
"encoder": resnet50_encoder,
"pretrained_settings": {
"imagenet": {
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
"url": "https://download.pytorch.org/models/deeplabv3_resnet50_coco-cd0a2569.pth",
"input_space": "RGB",
"input_range": [0, 1],
},
},
"params": {
"out_channels": (3, 64, 256, 512, 1024, 2048),
"block": Bottleneck,
"layers": [3, 4, 6, 3],
},
}
However, when I try to then use this encoder (model = smp.DeepLabV3Plus(encoder_name="deeplab_resnet50")
) I get the following error (truncated)
Traceback (most recent call last): File "<string>", line 1, in <module> File "/da/aics/projects/ComPath/envs/sweenke4/compath_art_det_repl/lib/python3.9/site-packages/segmentation_models_pytorch/decoders/deeplabv3/model.py", line 146, in __init__ self.encoder = get_encoder( File "/da/aics/projects/ComPath/envs/sweenke4/compath_art_det_repl/lib/python3.9/site-packages/segmentation_models_pytorch/encoders/__init__.py", line 85, in get_encoder encoder.load_state_dict(model_zoo.load_url(settings["url"])) File "<string>", line 26, in load_state_dict File "/da/aics/projects/ComPath/envs/sweenke4/compath_art_det_repl/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2189, in load_state_dict raise RuntimeError('Error(s) in loading state_dict for {}:nt{}'.format( RuntimeError: Error(s) in loading state_dict for resnet50_encoder:
Missing key(s) in state_dict: "encoder.backbone.conv1.weight", "encoder.backbone.bn1.weight", "encoder.backbone.bn1.bias", "encoder.backbone.bn1.running_mean", "encoder.backbone.bn1.running_var", "encoder.backbone.layer1.0.conv1.weight", "encoder.backbone.layer1.0.bn1.weight", "encoder.backbone.layer1.0.bn1.bias", ...
Unexpected key(s) in state_dict: "backbone.conv1.weight", "backbone.bn1.weight", "backbone.bn1.bias", "backbone.bn1.running_mean", "backbone.bn1.running_var", "backbone.bn1.num_batches_tracked", "backbone.layer1.0.conv1.weight", "backbone.layer1.0.bn1.weight", "backbone.layer1.0.bn1.bias", ...