본문 바로가기
논문리뷰

[논문 리뷰] iBOT 요약, 코드, 구현

by davidlds 2024. 7. 11.
반응형

논문을 상세히 번역하고 한단어씩 해석해주는 포스팅은 많다.

나는 논문을 누구나 알아듣도록 쉽고 간결하게 전달하고자 한다.

 

iBOT

iBOT: Image BERT Pre-Training with Online Tokenizer
ZHOU, Jinghao, et al. ibot: Image bert pre-training with online tokenizer. arXiv preprint arXiv:2111.07832, 2021.
 

 

저자의 의도

MIM 연구를 통해 visual tokenizer의 장점과 문제점을 연구해보자.
online tokenizer를 사용한 self-supervised 프레임워크 iBOT를 제시한다.
온라인의 의미는 사전 학습된 고정값이 아니라 모델과 함께 학습되는 것을 뜻한다.
마스크 패치 토큰에 self-distillation을 해서 teacher 네트워크로 사용해보자.
(online tokenizer = teacher 네트워크)

 

기존 문제점

MLM은 NLP에서 사실상의 표준이 되었지만 CV의 MIM은 아직 부족하다.
MLM의 가장 중요한 요소는 lingual tokenizer인데, visual tokenizer는 개선이 필요하다.
기존 연구에서는 pixel value를 예측했으나, 추상화된 의미를 학습하지 못하고 픽셀 디테일을 복원해야 하므로 비효율적이다.
BEiT와 같은 연구에서는 VAE를 visual tokenizer로 사용했지만, 이건 오프라인 데이터셋에서 학습되어 적응력에 한계가 있다.

 

해결 아이디어

1. iBOT overview

iBOT

image BERT pre-training with Online Tokenizer, iBOT 이다.
MIM을 knowledge distillation(KD)로 구성한다.
온라인 토크나이저를 twin teacher로 사용해 self-distillation 한다.
student 네트워크는 마스킹된 이미지를 받아 토크나이저의 아웃풋을 예측한다.

 

전체적인 프로세스가 진짜 진짜 진짜 진짜 복잡하다..... 반복해서 읽는 것을 추천한다.

먼저 기억할 것은 '이미지 토큰'과 '클래스 토큰'은 각각의 프로세스를 따로 가지고 있다.

 

2. MIM as Knowledge distillation

먼저 MIM 프로세스에 대한 설명이다.

이건 BEiT, MAE같은 마스킹과 매우 유사하기 때문에 쉽다.


이미지 토큰 시퀀스 x를 받으면, 0과 1으로 구성된 랜덤 마스크 m을 샘플링한다.
이때 m의 개수는 토큰 시퀀스와 같은 길이인 N 이다.
i번째 패치 토큰 x_i는 랜덤 마스크 m_i가 1일 경우 마스크 토큰 e로 대체된다.
이 결과가 손상시킨 이미지(corrupted image) x햇 이다.


MIM은 x햇에서 마스킹된 토큰 x틸드를 받아 원래 이미지 토큰 x을 복원(reconstruction) 한다.

Eq 1

이를 수학적으로 나타내면 Eq 1이 된다. (=BEiT)
여기서 기억해야 할 점은 이게 '이미지 토큰'에서 일어나는 일이다.

 

3. Self-distillation

DINO가 iBOT과 유사한 방법의 학습을 먼저 제안했으나 둘은 좀 다르다.

DINO와 iBOT 모두 MoCo의 방법을 응용하여 사용했다.

(MoCo -> DINO(MoCo 계승), iBOT(MoCo 개조)


MoCo와 DINO는 저자들이 사용하려는 것과는 다른 것을 teacher로 사용했다.

과거 이터레이션을 마치 teacher처럼 사용한다.

(저자들은 VAE를 사용하고자 함)

Eq 2

이를 수학적으로 나타내면 Eq 2가 된다. (DINO)
여기서 기억해야 할 점은 이게 '클래스 토큰'에서 일어나는 일이다.

 

MoCo가 궁금하다면 여기

 

4. iBOT

저자들은 Eq 1과 Eq 2에서 모티브를 따서 iBOT을 설계했다.
BEiT는 학습된 VAE를 사용했는데, 이를 온라인으로(모델과 같이) 학습을 시킨다면 더 자연스러워 질 것이라 생각했다.
다시 말하면 '온라인 비주얼 토크나이저'를 만들고자 했다.


self-distillation 아이디어의 경우 토큰 생성에 사용하고자 했다.

 

4-1. Framework

여기서부터 자꾸 왔다갔다 거려서 헷갈린다.... 반복해서 읽도록 하자....

self-distillation에서 가져온 부분과 MIM에서 가져온 부분을 각각 설명한다.

 

Framework

(self-distillation 부분 1)
서로 다른 augmentation을 시킨 u와 v가 있다.
여기에 blockwise masking으로 일부를 가려 u햇, v햇 이라 한다.

u햇을 student에 입력한 결과는 P_θ(u햇) 이다.
u햇을 teacher에 입력한 결과는 P_θ'(u햇) 이다.

 

Eq 3

training objective를 수식으로 나타내면 Eq 3이 된다.

training objective는 그냥 loss function 아니면 학습 목표라고 생각하면 된다.
u햇의 Eq 3과 v햇의 Eq 3을 평균을 내어 계산한다.

 

Framework

(MIM 부분 1)
그림을 잘 보면 클래스 토큰을 처리하는 헤드 h_CLS, 이미지 토큰을 처리하는 헤드 h_patch가 별개이다.
여기서 teacher의 이미지 토큰을 처리하는 h_patch_t는 비주얼 토크나이저 이다.
이 비주얼 토크나이저가 masked 패치 토큰에 대한 온라인 토큰을 만든다.
비주얼 토크나이저는 사전 학습을 위해 추가 단계가 필요하지 않다.
또한 이 토크나이저는 특정 데이터셋으로 학습한게 아니라 현재 데이터셋을 기준으로 도메인 knowledge를 추출한다.

 

Framework

(self-distillation 부분 2)
Fig 3의 클래스 토큰 부분을 보면 L_cls가 엑스 모양으로 표현되어 있다.
온라인 토크나이저가 유의미하게 만들기 위해 이렇게 했다.
크로스뷰로 클래스 토큰 간에 self-distillation 하는 것이다.
이 방법을 사용하면 시각적 의미가 부트스트래핑 된다... (필자도 이해 못함)
자세한 내용은 레퍼런스 논문을 보도록 하자.

 

(self-distillation 부분 3)
teacher는 back propagation 하지 않고 EMA를 통해서 업데이트 한다.

 

(추가 사항)
NLP의 '단어'와는 다르게 CV의 '이미지 패치'는 상대적으로 모호하다.
따라서 원핫 인코딩을 통한 토크나이제이션은 최적화되지 않을 수 있다.
원핫 인코딩하지 않은 (softmax 직전의 raw 확률 분포)를 사용했다.

 

결과 분석

1. Implementation

ViT를 백본 인코더 f로 사용, ImageNet-1K, image size 224, AdamW, batch size 1024,

(S 800/B 400/L 250) epochs, warm up, lr 5e-04, random MIM random 비율 0.1~0.5,

L_cls와 L_MIM은 1대1로 더함.

 

2. Experiment

Table 1~3

표식이 있는 경우 ImageNet-22K 학습한 것이다.
그 외에는 모두 ImageNet-1K만 학습한 것이다.

 

ViT-B/16은 k-NN, Lin 평가에서 가장 우수하다.

ImageNet-22K를 학습한 경우 Lin이 더 좋아진다.

 

Table 6

Object detection과 instance segmentation에서도 가장 좋은 성능이다.

 

코드 및 구현

오피셜 깃허브 코드 중 모델 부분

 

iBOT 헤드 부분 발췌

# Copyright (c) ByteDance, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import utils

from utils import trunc_normal_


class iBOTHead(DINOHead):

    def __init__(self, *args, patch_out_dim=8192, norm=None, act='gelu', last_norm=None, 
                 nlayers=3, hidden_dim=2048, bottleneck_dim=256, norm_last_layer=True, 
                 shared_head=False, **kwargs):
        
        super(iBOTHead, self).__init__(*args,
                                        norm=norm,
                                        act=act,
                                        last_norm=last_norm,
                                        nlayers=nlayers,
                                        hidden_dim=hidden_dim,
                                        bottleneck_dim=bottleneck_dim,
                                        norm_last_layer=norm_last_layer, 
                                        **kwargs)

        if not shared_head:
            if bottleneck_dim > 0:
                self.last_layer2 = nn.utils.weight_norm(nn.Linear(bottleneck_dim, patch_out_dim, bias=False))
                self.last_layer2.weight_g.data.fill_(1)
                if norm_last_layer:
                    self.last_layer2.weight_g.requires_grad = False
            else:
                self.mlp2 = nn.Linear(hidden_dim, patch_out_dim)
                self.last_layer2 = None

            self.last_norm2 = self._build_norm(last_norm, patch_out_dim, affine=False, **kwargs)
        else:
            if bottleneck_dim > 0:
                self.last_layer2 = self.last_layer
            else:
                self.mlp2 = self.mlp[-1]
                self.last_layer2 = None

            self.last_norm2 = self.last_norm

    def forward(self, x):
        if len(x.shape) == 2:
            return super(iBOTHead, self).forward(x)

        if self.last_layer is not None:
            x = self.mlp(x)
            x = nn.functional.normalize(x, dim=-1, p=2)
            x1 = self.last_layer(x[:, 0])
            x2 = self.last_layer2(x[:, 1:])
        else:
            x = self.mlp[:-1](x)
            x1 = self.mlp[-1](x[:, 0])
            x2 = self.mlp2(x[:, 1:])
        
        if self.last_norm is not None:
            x1 = self.last_norm(x1)
            x2 = self.last_norm2(x2)
        
        return x1, x2

 

 

 

관련 논문 리스트 (스크롤 내려서 Paper List 참고)

반응형