I’ve trained a RaResNet50 from here: robust principles. I used this command: make experiments/RaResNet50/.done_train
and then I would want to check the robustness by evaluating on CIFAR-10.
So I did this:
import foolbox as fb
import torch
import torch.nn as nn
import numpy as np
import os
import hydra
from robustarch.models.model import NormalizedConfigurableModel
MODEL_PATH = '/home/name/robust-principles/trained_models/RaResNet50_fat_train_eps5/phase3/model_best.pt'
def run_attack():
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]
kwargs = {
'stage_widths': [288, 576, 1120, 2160],
'kernel': 3,
'strides': [2, 2, 2, 2],
'dilation': 1,
'norm_layer': [
hydra.utils.get_class('torch.nn.Identity'),
hydra.utils.get_class('torch.nn.BatchNorm2d'),
hydra.utils.get_class('torch.nn.BatchNorm2d')
],
'activation_layer': [
hydra.utils.get_class('torch.nn.SiLU'),
hydra.utils.get_class('torch.nn.SiLU'),
hydra.utils.get_class('torch.nn.SiLU')
],
'group_widths': [36, 72, 140, 270],
'bottleneck_multipliers': [0.25, 0.25, 0.25, 0.25],
'downsample_norm': hydra.utils.get_class('torch.nn.BatchNorm2d'),
'depths': [5, 8, 13, 1],
'dense_ratio': None,
'stem_type': hydra.utils.get_class('robustarch.models.model.Stem'),
'stem_width': 96,
'stem_kernel': 7,
'stem_downsample_factor': 2,
'stem_patch_size': None,
'block_constructor': hydra.utils.get_class('robustarch.models.model.BottleneckBlock'),
'ConvBlock': hydra.utils.get_class('robustarch.models.model.Conv2dNormActivation'),
'se_ratio': 0.25,
'se_activation': hydra.utils.get_class('torch.nn.ReLU'),
'weight_init_type': 'resnet',
'num_classes': 1000
}
# I got these configurations from printing them in the NormalizedConfigurableModel class when I started my training first
model = NormalizedConfigurableModel(
mean=mean,
std=std,
**kwargs
)
checkpoint = torch.load(MODEL_PATH)
model.load_state_dict(checkpoint)
...
But I got this error:
File "/home/name/anaconda3/envs/ra-principles/lib/python3.9/site-packages/torch/nn/modules/module.py", line 2189, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:nt{}'.format(
RuntimeError: Error(s) in loading state_dict for NormalizedConfigurableModel:
Missing key(s) in state_dict: "stem.stem.stem.0.weight", "stem.stem.stem.1.weight", "stem.stem.stem.1.bias",.....
it appears in this line of my code
model.load_state_dict(checkpoint)
The error message is a very big list of missing keys. What did I do wrong?