본문 바로가기
논문리뷰

[논문 리뷰] MLP mixer 요약, 코드, 구현

by davidlds 2023. 7. 15.
반응형

논문을 상세히 번역하고 한단어씩 해석해주는 포스팅은 많다.

나는 논문을 누구나 알아듣도록 쉽고 간결하게 전달하고자 한다.

 

MLP-Mixer

MLP-Mixer: An all-MLP Architecture for Vision
TOLSTIKHIN, Ilya O., et al. Mlp-mixer: An all-mlp architecture for vision. Advances in neural information processing systems, 2021, 34: 24261-24272.
 

MLP-Mixer 아키텍쳐
MLP-Mixer 아키텍쳐

저자의 의도

CV 분야에서 가장 보편적인 모델 2개가 CNN과 Transformer 이다.

근데 근디 근대 저자의 관점은...

convolution layer와 attention layer가 반드시 필요한건 아니다.

그래서 오직 MLP만 사용하는 MLP-Mixer architecture 제시하고자 했다.

 

기존 문제점

CNN과 ViT의 architecture가 너무 복잡하다.

틀린 말은 아니다. CNN과 ViT(Vision Transformer)는 복잡하다.

inductive bias나 gradient decent 등을 방지하기 위해

다양한 기법 및 다양한 레이어를 사용한다.

온전히 딥 레이어에 모든 것을 맡기면 어떻게 될까?

 

배경 지식

일단 패치 개념.

 

패치 개념
패치 개념

CV 분야에서 특히 ViT에서 이미지를 패치 단위로 쪼개서 넣는다.

패치는 그냥 이미지를 나눈 조각 하나를 부르는 용어다. 쫄지 말자.

 

패치 개념 보조 자료
패치 개념 보조 자료

이미지 사이즈가 384x384 다.

이걸 패치 사이즈 16x16으로 쪼개면 패치가 (384/16)^2 = 576개 나온다.

그리고 채널은 3개다. RGB 3개.

각 패치 1개는 16x16 사이즈에 채널이 3개니 모든 서브픽셀을 세면 768개 이다.

이 서브픽셀을 그냥 지금부터 embedding dimension 이라고 부르자.

그럼 위와 같은 행렬이 input 된다는 것이 이해가 될 것이다.

 

(이건 굳이 안 봐도 됨)

CNN의 대표주자 ResNet과 ViT를 참고하자.

 

ResNet 논문 리뷰, 논문 원문, 논문 요약, 논문 구현, Deep Residual Learning for Image Recognition

논문을 상세히 번역하고 한단어씩 해석해주는 포스팅은 많다. 나는 논문을 누구나 알아듣도록 쉽고 간결하게 전달하고자 한다. ResNet Deep Residual Learning for Image Recognition Kaiming He, Xiangyu Zhang, Shaoqin

davidlds.tistory.com

 

[논문 리뷰] Vision Transformer(ViT) | 논문 원문, 논문 요약, 논문 구현, AN IMAGE IS WORTH 16X16 WORDS:TRANSFORMER

논문을 상세히 번역하고 한단어씩 해석해주는 포스팅은 많다. 나는 논문을 누구나 알아듣도록 쉽고 간결하게 전달하고자 한다. ViT AN IMAGE IS WORTH 16X16 WORDS:TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE DOSOVITS

davidlds.tistory.com

 

해결 아이디어

1. Only MLP

convolution 이나 self-attention 레이어를 사용하지 말자.

이 문장은 논문을 관통하는 가장 핵심적인 것이라 계속 반복해서 나온다.

MLP가 지역적 공간과 feature 채널에 반복적으로 작용되도록 설계할 것이다.

 

본격적으로 들어가기 앞서 용어 정리를 해주겠다.

이 저자 뿐만 아니라 온 세상에 용어를 섞어 써서 상당히 킹받는다.

뉴비들은 어쩌란 말인지 모르겠다. (뉴비 절단 멈춰!)

그래서 난 친절히 설명해줄거다.

패치 개념 설명 자료
패치 개념 설명 자료

위에서 설명한 패치와 픽셀은 이제부터 토큰과 채널이라고 부를거다.

이는 ViT의 기원이 BERT인 것과 딥러닝의 기원이 통계학인 것에 있는데...

이딴건 알 필요 없다. 그냥 받아들이자.

Token은 곧 패치(이미지 조각 하나) 이고, Channel은 곧 픽셀(서브픽셀 1개) 이다.

인풋
인풋

 

2. Mixing MLPs

Channel-mixing MLPs

채널 믹싱 MLP
채널 믹싱 MLP

2가지 mixing MLP 중에 channel-mixing MLP 이다.

서로 다른 채널 간의 커뮤니케이션 역할을 한다. (토큰(=패치)이 같으니까 공간적으로는 같은 위치)

각 토큰(=패치)는 독립적이며, 인풋 테이블의 열 방향으로 작용한다.

1x1 Conv으로 간주할 수 있다. (1x1 Conv는 채널을 줄이는데 사용 되는 레이어)

 

Token-mixing MLPs

토큰 믹싱 MLP
토큰 믹싱 MLP

다음으로 token-mixing MLP 이다.

서로 다른 토큰 간에 커뮤니케이션 역할을 한다. (토큰(=패치)이 다르니까 공간적 다른 위치)

각 채널(=픽셀)은 독립적이며, 인풋 테이블의 행 방향으로 작용한다.

single-channel depth-wise Conv으로 간주할 수 있다.

 

3. Architecture

아키텍쳐 오버뷰
아키텍쳐 오버뷰

패치 밑작업 -> Channel-mixing -> Norm -> Token-mixing -> GAP -> FC

복잡하다고 새로운 아키텍쳐를 창조한 저자 답게 매우 간단한 구조다.

 

4. 모델 별 아키텍쳐 특징 분석

Mixer의 아키텍쳐는 두 feature가 명확하게 나눠져 있다.

따라서 더 직관적이고 더 확실하게 문제점을 파악하고 분석할 수 있는 장점이 있다.

(물론 더 복잡한 개념을 이해하지 못하는 단점도 있다.)

 

Features at a given spatial location (i)

각 모델의 지역 내부의 feature를 다루는 레이어는 뭘까

CNN : 1x1 Conv 메인, large size kernel Conv 서포트
ViT : MLP 메인, self-attention 서포트
Mixer : channel-mixing MLP 단독

 

Features between different spatial locations (ii)

각 모델의 서로 다른 지역의 feature를 다루는 레이어는 뭘까
CNN : large size kernel Conv 단독
ViT : self-attention 단독
Mixer : token-mixing MLP 단독

 

Mixer를 보면 확실하게 두 개가 분리되어 동작한다.

 

실험 조건 및 결과

Pre-training

Resolution = 224, batch size = 4096, epochs = 300, learning rate = 0.001, linear decay

Dataset = JFT-300M or ImageNet-21k

 

Fine-tuning

Resolution = 224, Batch size = 512, steps = 7k, learning rate = 0.001, cosine decay

Dataset = ImageNet

 

실험결과 1
실험결과 1

 

Invariance to input permutations

실험결과 2
실험결과 2

내가 논문을 읽게 된 계기는 이거 때문인데.... 별 내용은 없다.

Mixer와 CNN 구조가 가지는 inductive bias를 비교한 실험이다.

 

실험 조건 : Mixer and ResNet on JFT-300M(pre-training) and ImageNet(5 shots)

Dataset : original, patch shuffling + pixel shuffling within each patch, pixel shuffling

오리지널 데이터셋(원본) / 패치+픽셀 셔플링(패치 내에서 픽셀 셔플 후 패치 셔플) / 픽셀 셔플링(걍 모든 픽셀 셔플링)

믹서는 패치+픽셀 셔플에 불변한다. ResNet의 강력한 inductive bias는 픽셀 순서에 의존한다.

딱 이정도 까지만 분석해놔서... 볼게 없다.

 

논문 구현

import torch
import torch.nn as nn
import torch.nn.functional as F

IMAGE_SIZE = 224
PATCH_SIZE = 16
NUM_CLASSES = 10
NUM_LAYERS = 8
HIDDEN_FEATURES = 512
TOKENS_MLP_DIM = 256
CHANNELS_MLP_DIM = 2048

class MLP(nn.Module):
    def __init__(self, in_features, hidden_features, out_features):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        x = F.gelu(self.fc1(x))
        x = self.fc2(x)
        return x

class MixerLayer(nn.Module):
    def __init__(self, num_patches, hidden_features, tokens_mlp_dim, channels_mlp_dim):
        super(MixerLayer, self).__init__()
        self.token_mixing = nn.Sequential(
            nn.LayerNorm(hidden_features),
            nn.Linear(num_patches, tokens_mlp_dim),
            nn.GELU(),
            nn.Linear(tokens_mlp_dim, num_patches)
        )
        self.channel_mixing = nn.Sequential(
            nn.LayerNorm(hidden_features),
            nn.Linear(hidden_features, channels_mlp_dim),
            nn.GELU(),
            nn.Linear(channels_mlp_dim, hidden_features)
        )

    def forward(self, x):
        y = x.permute(0, 2, 1)
        y = self.token_mixing(y)
        y = y.permute(0, 2, 1)
        x = x + y
        y = self.channel_mixing(x)
        x = x + y
        return x

class MLPMixer(nn.Module):
    def __init__(self, image_size, patch_size, num_classes, num_layers, hidden_features, tokens_mlp_dim, channels_mlp_dim):
        super(MLPMixer, self).__init__()
        num_patches = (image_size // patch_size) ** 2
        self.patch_embedding = nn.Conv2d(3, hidden_features, kernel_size=patch_size, stride=patch_size)
        self.mixer_layers = nn.ModuleList([
            MixerLayer(num_patches, hidden_features, tokens_mlp_dim, channels_mlp_dim)
            for _ in range(num_layers)
        ])
        self.layer_norm = nn.LayerNorm(hidden_features)
        self.fc = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        x = self.patch_embedding(x).flatten(2).transpose(1, 2)
        for mixer_layer in self.mixer_layers:
            x = mixer_layer(x)
        x = self.layer_norm(x)
        x = x.mean(dim=1)
        x = self.fc(x)
        return x

if __name__ == '__main__':
    model = MLPMixer(image_size=IMAGE_SIZE,
                     patch_size=PATCH_SIZE,
                     num_classes=NUM_CLASSES,
                     num_layers=NUM_LAYERS,
                     hidden_features=HIDDEN_FEATURES,
                     tokens_mlp_dim=TOKENS_MLP_DIM,
                     channels_mlp_dim=CHANNELS_MLP_DIM,
                     )

간단하고 별거 없다.

 

끝.

 

관련 논문 리스트 (스크롤 내려서 Paper List 참고)

반응형