I found at least three MLP-only architectures for computer vision (i.e., architectures that avoid the attention mechanism entirely) that report strong results on ImageNet (70%+) and on other benchmarks like CIFAR10 and CIFAR100 (paper 1 code, paper 2 code, paper 3 code): “g-mlp”, “MLP-Mixer”, and “do-you-even-need-attention”.
I’ve been using paper 2’s code and trying to replicate the results (at least approximately), but I can’t figure out what I’m missing: I’m getting around 50-60% on CIFAR10, while the papers report 90%+.
The architecture is largely the same across all of these methods, so I’m wondering what’s going wrong. I’ve tried every data augmentation I could think of, different optimizers, different batch sizes, and different numbers of layers to make the model larger or smaller.
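For concreteness, here is a minimal sketch of the standard random-crop-plus-flip CIFAR baseline, representative of the kind of pipeline I tried (the exact parameters varied between runs, so this is illustrative rather than any single run):

CIFAR10_baseline_aug = transforms.Compose([
    transforms.RandomCrop(32, padding=4),    # pad each side by 4, crop back to 32x32
    transforms.RandomHorizontalFlip(p=0.5),  # standard CIFAR flip
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616]),
])

The full script I’m running is below.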
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import numpy as np
from einops.layers.torch import Rearrange
import torchvision.transforms as T
from torchvision.transforms import InterpolationMode
from torchvision.transforms import v2
import math
import torch.nn.functional as F
from functools import partial
from timm.models.layers import DropPath, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg, Mlp
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dataset_name = 'cifar10'
root = '../data'  # the dataset is downloaded here if not already present
batch_size = 256
if dataset_name == 'cifar10':
# CIFAR10_trans = transforms.Compose([
# transforms.RandomApply([
# transforms.RandomInvert(p=0.1),
# transforms.RandomAdjustSharpness(sharpness_factor=1, p=0.8),
# transforms.RandomAutocontrast(p=0.5),
# transforms.RandomEqualize(p=0.6),
# transforms.RandomSolarize(threshold=2, p=0.5),
# transforms.RandomPosterize(bits=7, p=0.3),
# transforms.ColorJitter(brightness=0.6, contrast=0.6, saturation=0.6, hue=0.2),
# transforms.RandomRotation(degrees=2, interpolation=InterpolationMode.NEAREST, expand=False),
# transforms.RandomAffine(9, translate=(0.3, 0.3), scale=(0.5, 1.5), shear=8, interpolation=InterpolationMode.NEAREST, fill=(0, 0, 0)),
# ], p=0.5),
# transforms.ToTensor(),
# transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
# ])
CIFAR10_trans = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616])
])
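    # NOTE: the active pipeline above only converts to a tensor and normalizes;
    # all of the augmentations are in the commented-out block.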
train_set = datasets.CIFAR10(root=root, train=True, transform=CIFAR10_trans, download=True)
    train_size = int(len(train_set) * 1.0)  # currently uses 100% of the data for training
    valid_size = len(train_set) - train_size  # 0 with the split above, so the validation set is empty
train_set, val_set = torch.utils.data.random_split(train_set, [train_size, valid_size])
test_set = datasets.CIFAR10(root=root, train=False, transform=CIFAR10_trans, download=True)
num_classes = 10
train_loader = torch.utils.data.DataLoader(
dataset=train_set,
batch_size=batch_size,
shuffle=True)
val_loader = torch.utils.data.DataLoader(
dataset=val_set,
batch_size=batch_size,
shuffle=False)
test_loader = torch.utils.data.DataLoader(
dataset=test_set,
batch_size=batch_size,
shuffle=False)
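# NOTE: with the 100%/0% split above, val_loader is empty and never used.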
def requires_grad(module, requires_grad):
for p in module.parameters():
p.requires_grad = requires_grad
class LinearBlock(nn.Module):
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, num_tokens=197):
super().__init__()
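        # NOTE: num_heads, qkv_bias, qk_scale and attn_drop are accepted only for
        # signature compatibility with the attention-based ViT block; they are unused here.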
# First stage
self.mlp1 = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)
self.norm1 = norm_layer(dim)
# Second stage
self.mlp2 = Mlp(in_features=num_tokens, hidden_features=int(
num_tokens * mlp_ratio), act_layer=act_layer, drop=drop)
self.norm2 = norm_layer(num_tokens)
# Dropout (or a variant)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
x = x + self.drop_path(self.mlp1(self.norm1(x)))
x = x.transpose(-2, -1)
x = x + self.drop_path(self.mlp2(self.norm2(x)))
x = x.transpose(-2, -1)
return x
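# Hedged sanity check for LinearBlock (my own sketch, not part of the original code):
# with patch_size=16 on 32x32 inputs there are 1 + (32 // 16)**2 = 5 tokens; the block
# mixes channels with mlp1 and mixes tokens (via the transpose) with mlp2.
# blk = LinearBlock(dim=192, num_heads=3, num_tokens=5)
# blk(torch.randn(2, 5, 192)).shape  # -> torch.Size([2, 5, 192])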
class PatchEmbed(nn.Module):
""" Wraps a convolution """
def __init__(self, patch_size=16, in_chans=3, embed_dim=768):
super().__init__()
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
x = self.proj(x)
return x
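# Hedged shape check (my sketch): PatchEmbed(patch_size=16, in_chans=3, embed_dim=192)
# maps (B, 3, 32, 32) -> (B, 192, 2, 2), i.e. a 2x2 grid of 4 patch embeddings per image.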
class LearnedPositionalEncoding(nn.Module):
""" Learned positional encoding with dynamic interpolation at runtime """
def __init__(self, height, width, embed_dim):
super().__init__()
self.height = height
self.width = width
self.pos_embed = nn.Parameter(torch.zeros(1, embed_dim, height, width))
self.cls_pos_embed = nn.Parameter(torch.zeros(1, 1, embed_dim))
trunc_normal_(self.pos_embed, std=.02)
trunc_normal_(self.cls_pos_embed, std=.02)
def forward(self, x):
B, C, H, W = x.shape
if H == self.height and W == self.width:
pos_embed = self.pos_embed
else:
pos_embed = F.interpolate(self.pos_embed, size=(H, W), mode='bilinear', align_corners=False)
return self.cls_pos_embed, pos_embed
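# Note (my sketch): the table is stored at (height, width) = (14, 14) by default, so for
# 32x32 CIFAR inputs with patch_size=16 it is bilinearly resized down to the 2x2 patch
# grid on every forward pass.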
class LinearVisionTransformer(nn.Module):
"""
Basically the same as the standard Vision Transformer, but with support for resizable
or sinusoidal positional embeddings.
"""
def __init__(self, *, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm,
positional_encoding='learned', learned_positional_encoding_size=(14, 14), block_cls=LinearBlock):
super().__init__()
# Config
self.num_classes = num_classes
self.patch_size = patch_size
self.num_features = self.embed_dim = embed_dim
# Patch embedding
self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
# Class token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
# Positional encoding
if positional_encoding == 'learned':
height, width = self.learned_positional_encoding_size = learned_positional_encoding_size
self.pos_encoding = LearnedPositionalEncoding(height, width, embed_dim)
else:
            raise NotImplementedError('Unsupported positional encoding')
self.pos_drop = nn.Dropout(p=drop_rate)
# Stochastic depth
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
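        # NOTE: the input resolution (32) is hardcoded in num_tokens below, so the blocks
        # only match CIFAR-sized inputs: 1 cls token + (32 // patch_size)**2 patches.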
self.blocks = nn.ModuleList([
block_cls(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, num_tokens=1 + (32 // patch_size)**2)
for i in range(depth)])
self.norm = norm_layer(embed_dim)
# Classifier head
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
# Init
trunc_normal_(self.cls_token, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
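    # NOTE: this set is only honored by timm's optimizer factory; the plain AdamW call in
    # the training loop below applies weight decay to every parameter, including these.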
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x):
# Patch embedding
B, C, H, W = x.shape # B x C x H x W
x = self.patch_embed(x) # B x E x H//p x W//p
# Positional encoding
# NOTE: cls_pos_embed for compatibility with pretrained models
cls_pos_embed, pos_embed = self.pos_encoding(x)
# Flatten image, append class token, add positional encoding
cls_tokens = self.cls_token.expand(B, -1, -1)
x = x.flatten(2).transpose(1, 2) # flatten
x = torch.cat((cls_tokens, x), dim=1) # class token
pos_embed = pos_embed.flatten(2).transpose(1, 2) # flatten
pos_embed = torch.cat([cls_pos_embed, pos_embed], dim=1) # class pos emb
x = x + pos_embed
x = self.pos_drop(x)
# Transformer
for blk in self.blocks:
x = blk(x)
# Final layernorm
x = self.norm(x)
return x[:, 0]
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
@register_model
def linear_tiny(pretrained=False, **kwargs):
model = LinearVisionTransformer(
patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
return model
@register_model
def linear_base(pretrained=False, **kwargs):
model = LinearVisionTransformer(
patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
return model
@register_model
def linear_large(pretrained=False, **kwargs):
model = LinearVisionTransformer(
patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
model.default_cfg = _cfg()
return model
if __name__ == '__main__':
# Test
x = torch.randn(2, 3, 32, 32).to(device)
model = linear_tiny().to(device)
# model = linear_base().to(device)
out = model(x)
print('-----')
print(f'num params: {sum(p.numel() for p in model.parameters())}')
print(out.shape)
loss = out.sum()
loss.backward()
print('Single iteration completed successfully')
training_epochs = 100
criterion = nn.CrossEntropyLoss()
model_opt = optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.99), weight_decay=5e-5)
scheduler = optim.lr_scheduler.CosineAnnealingLR(model_opt, 300, eta_min=1e-6, last_epoch=-1, verbose=False)
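    # NOTE: scheduler.step() is called once per batch below, so T_max=300 is measured in
    # optimizer steps (~196 batches/epoch at batch_size=256), i.e. the annealing finishes
    # partway through the second epoch and the closed-form cosine then raises the LR again.
    # A hedged alternative (my assumption, not what produced the log below) would anneal
    # once over the whole run:
    # total_steps = training_epochs * len(train_loader)
    # scheduler = optim.lr_scheduler.CosineAnnealingLR(model_opt, T_max=total_steps, eta_min=1e-6)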
# cutmix = v2.CutMix(num_classes=10)
# mixup = v2.MixUp(num_classes=10)
# cutmix_or_mixup = v2.RandomChoice([cutmix, mixup])
# training loop
for epoch in tqdm(range(training_epochs)):
batch_losses = []
batch_accuracies = []
for batch_idx, (train_x, train_target) in enumerate(train_loader):
train_x, train_target = train_x.to(device), train_target.to(device)
# train_x, train_target = cutmix_or_mixup(train_x, train_target)
train_out = model(train_x)
_, train_predicted = torch.max(train_out.data, 1)
train_loss = criterion(train_out, train_target)
model_opt.zero_grad()
train_loss.backward()
model_opt.step()
scheduler.step()
batch_losses.append(train_loss.item())
# batch_accuracies.append(((train_predicted == train_target).sum().item() / train_predicted.size(0)))
# test loop
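        # NOTE: the model stays in train mode during evaluation; harmless here because
        # drop_rate and drop_path_rate are 0, but model.eval()/model.train() toggling
        # would be needed if either were nonzero.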
with torch.no_grad():
test_accuracy = []
for batch_idx, (test_x, test_target) in enumerate(test_loader):
test_x, test_target = test_x.to(device), test_target.to(device)
test_outputs = model(test_x)
_, test_predicted = torch.max(test_outputs.data, 1)
test_accuracy.append((test_predicted == test_target).sum().item() / test_predicted.size(0))
print(f'test accuracy: {np.mean(test_accuracy)}')
>>>
1%| | 1/100 [00:11<18:28, 11.20s/it]test accuracy: 0.31982421875
2%|▏ | 2/100 [00:22<18:05, 11.07s/it]test accuracy: 0.3609375
3%|▎ | 3/100 [00:33<17:45, 10.98s/it]test accuracy: 0.41474609375
4%|▍ | 4/100 [00:44<17:35, 10.99s/it]test accuracy: 0.47060546875
5%|▌ | 5/100 [00:54<17:21, 10.96s/it]test accuracy: 0.4900390625
6%|▌ | 6/100 [01:05<17:05, 10.90s/it]test accuracy: 0.47626953125
7%|▋ | 7/100 [01:16<16:58, 10.95s/it]test accuracy: 0.50859375
8%|▊ | 8/100 [01:27<16:41, 10.88s/it]test accuracy: 0.5208984375
9%|▉ | 9/100 [01:38<16:25, 10.83s/it]test accuracy: 0.4986328125
10%|█ | 10/100 [01:48<16:10, 10.78s/it]test accuracy: 0.52705078125
11%|█ | 11/100 [01:59<15:56, 10.75s/it]test accuracy: 0.544921875
12%|█▏ | 12/100 [02:10<15:48, 10.78s/it]test accuracy: 0.5205078125
13%|█▎ | 13/100 [02:21<15:38, 10.79s/it]test accuracy: 0.54580078125
14%|█▍ | 14/100 [02:32<15:28, 10.80s/it]test accuracy: 0.5607421875
15%|█▌ | 15/100 [02:42<15:18, 10.80s/it]test accuracy: 0.540234375
16%|█▌ | 16/100 [02:53<15:07, 10.80s/it]test accuracy: 0.55576171875
17%|█▋ | 17/100 [03:04<14:59, 10.83s/it]test accuracy: 0.566796875
18%|█▊ | 18/100 [03:15<14:48, 10.84s/it]test accuracy: 0.53642578125
19%|█▉ | 19/100 [03:26<14:37, 10.84s/it]test accuracy: 0.55302734375