논문을 상세히 번역하고 한단어씩 해석해주는 포스팅은 많다.
나는 논문을 누구나 알아듣도록 쉽고 간결하게 전달하고자 한다.
MambaOut
엄청난 어그로성 제목의 논문이 나왔다.
(이정도면 유투바 아니냐고....)
접근이 신선하고 비슷한 생각을 해본 적이 있기 때문에 읽었는데...
결론이 좀 허망하게 난 경향이 있다.
너무 기대하고 읽을 필요는 없겠다.
저자의 의도
Mamba는 어텐션 매커니즘의 2차함수 복잡성을 해결하는데 매우 뛰어나다.
근데 맘바는 긴 시퀀스와 autoregressive 특성의 task에서만 유리하다.
(image classification task는 둘 다 아니므로 부적합하다.)
(detection과 segmentation task는 긴 시퀀스만 적합하다.)
이 가설을 증명해보자.
맘바 블럭을 쌓긴 하지만 SSM 모듈은 빠진 MambaOut 모델을 만들어 증명해보자.
기존 문제점
트랜스포머는 긴 시퀀스에 불리하다.
이를 해결하기 위해 나온 맘바는 비전에 맞게 개조해야 한다.
Vision mamba, Vmamba 등이 나왔지만 여전히 퍼포먼스가 높지 않다.
Do we really need Mamba for Vision?
진짜 vision에 맘바가 필요한가??? 라는 의문이 나올 수 밖에 없다.
해결 아이디어
1. MambaOut
맘바 블럭에서 SSM 모듈을 뺀 Gated CNN 구조를 만들었다.
이 모델을 SSM을 포함한 비전 모델 Vim, Vmamba 등과 비교해보자.
비전에서 NLP 차용 모델이 나오면 무조건 들어오는 공격이다.
이거 필요해? 이건? 이건 빼도 될껄? 이건 왜 넣어?
경험적으로는 '다 필요해'가 정답이지만 비판적 사고는 늘 필요하다.
그리고 이런 과정에서 모듈의 역할을 잘 분석할 수 있다.
이 논문은 그런 내용인 것이다.
저 비전 맘바에 대한 내용이 궁금하면 여기 아래.
https://davidlds.tistory.com/38
[논문 리뷰] Vision Mamba(Vim) 요약, 코드, 구현
논문을 상세히 번역하고 한단어씩 해석해주는 포스팅은 많다. 나는 논문을 누구나 알아듣도록 쉽고 간결하게 전달하고자 한다. Vision Mamba ZHU, Lianghui, et al. Vision mamba: Efficient visual representation learn
davidlds.tistory.com
2. Conceptual discussion
2-1. What tasks is Mamba suitable for?
어어 쫄지 말자. 난 수식을 자세히 설명할 생각이 없다.
최대한 간결하게 설명할 것이다.
맘바는 selective SSM을 사용한다.
SSM을 수식으로 설명하면 Equation (1), (2), (3)으로 표현할 수 있다.
Eq 2의 재귀적인 특징이 SSM(=RNN like)과 어텐션 구조를 구별하는 특징이다.
뭔소린지 모르겠으니 이제 그림으로 알아보자.
2가지 특징을 볼거다. Long sequence, Autoregressive 2개다.
[Long sequence]
먼저 오른쪽 그림을 보자. 이건 RNN-like 인 SSM 이다.
여기서 hidden state인 h는 크기가 고정되어 있고 여기에 모든 이력 정보를 다 때려넣는다.
이 고정된 크기는 메모리(컴퓨터 메모리 말고 기억력 메모리)에 필연적으로 손실이 온다는 것을 의미한다.
하지만 고정된 크기는 연산의 복잡성을 줄여주는 것도 보장해준다.
이제 왼쪽 그림을 보자. 이건 Causal 인 어텐션 블럭 이다.
반면에 어텐션 구조는 모든 key와 value를 메모리에 저장한다.
이 '커지는 메모리'는 이론적으로 손실이 없다.
하지만 '커지는 메모리'는 연산의 복잡성이 점점 커지는 것을 의미한다.
결론은 SSM의 구조는 본질적으로 손실이 발생한다.
단기 기억 쪽에서는 이 손실 발생 아키텍처가 당연하게도 어텐션 구조의 무손실 기억보다 약할 수 밖에 없다.
그래서 짧은 시퀀스에서는 맘바가 어텐션 구조를 이길 수 없다.
하지만 긴 시퀀스 상황이 오면 어텐션 구조는 이차함수적 복잡성에 의해 흔들림이 생긴다.
맘바는 메모리를 효율적으로 합치고 긴 시퀀스를 스무스하게 관리할 수 있다.
그래서 긴 시퀀스에서 맘바는 뚜렷한 장점을 보인다.
[Autoregressive]
왼쪽 그림을 보자.
SSM의 큰 문제는 h는 오직 이번과 직전 정보만 접근할 수 있다.
이런 조건의 모델을 causal mode라고 명명하여 예를 들었다.
causal mode는 본인 직전까지의 토큰만 접근할 수 있다.
예를들어 GPT's attention, Mamba's SSM이 있다.
이 모드는 autoregressive한 generation task에 적합하다.
반면에 fully-visible mode는 이후의 토큰까지 다 접근할 수 있다.
예를들어 BERT's attention, ViT's attention이 있다.
이 모드는 시퀀스 전체를 이해하는 understanding task에 적합하다.
어텐션은 fully-visible mode가 기본이고 손쉽게 causal mode로 바꿀 수 있다.
하지만 SSM은 본질적으로 이 모드를 넘나들 수 없다.
따라서 맘바는 autoregressive한 task에 적합하다.
2-2. Do visual recognition tasks have very long sequences?
이 부분은 ImageNet의 이미지는 224 밖에 안되기 때문에 긴 시퀀스일 필요가 없다는 말을 한다.
난 동의하지 않기 때문에 자세한 부분은 넘어가겠다.
이 224 사이즈는 단순히 classification task에서 편의상 비교하기 쉽게 통일한 것이다.
실제로 이를 적용해 서비스에 사용할 때는 긴 시퀀스에 많은 어려움을 겪고 있다.
따라서 이런식으로 해석하는건 옳지 않다.
2-3. Do visual recognition tasks need causal token mixing mode?
이미지 인식 분야는 시퀀스 전체를 이해하는 understanding task 이다.
이후 토큰을 볼 수 없는 causal mode로 바꾸면 기능 저하가 생길 수 있다.
visual recognition task는 causal mode로 바꿀 필요가 없다.
예를 들면 트랜스포머는 전체 이미지를 조각내 토큰으로 만든다.
이 토큰을 다 보고 이미지가 어떤 class인지 맞추게 한다는 것이다.
반면에 causal mode의 GPT같은 애들은 직전 토큰을 보고 다음 토큰을 예측한다.
각각의 아키텍처에 딱 맞는 task가 존재하기 때문에 이를 바꿀 필요가 없다는 말이다.
visual task는 모든 토큰을 보는게 맞는 아키텍처 이다.
2-4. Hypotheses regarding the necessity of Mamba for vision
아무튼 저자들의 주장을 정리해보자.
롱 시퀀스, autoregressive 2가지 특징이 없는 image classification에 SSM은 필요하지 않다.
롱 시퀀스 특징은 있고 autoregressive 특징은 없는 detection이나 segmentation에 SSM은 사용해볼 가치가 있다.
2-5. Gated CNN and MambaOut
MambaOut 모델은 Gated CNN 블럭을 베이스로 한다.
Gated CNN과 Mamba의 가장 큰 차이는 SSM이 있고 없고 이다.
결과 분석
1. Image classification on ImageNet
비슷한 크기끼리 비교했을 때 SSM을 포함한 Mamba들 보다 SSM은 지운 MambaOut이 더 좋은 성능을 보인다.
classification에서는 SSM이 불필요하므로 제거하는게 좋다.
(나는 저자의 주장에 동의하지 않는다. 224 이미지로 실제 기술에 적용할 수 있는건 없다. 이건 단순히 평가일 뿐이다. 실제 기술에서는 최소 1080p 사이즈는 될 것이다. 이 task는 모델이 얼마나 이미지를 잘 해석 하는지를 보는 것이지 저 accuracy 0.X%가 중요한게 아니다. ImageNet 평가는 절대적이지 않다. 그걸 보완하기 위해서 다양한 추가 데이터셋도 나왔다. 단순히 이 acc를 보고 모델이 좋아졌다고 판단하기에는 간과한 부분이 너무 많다.)
2. Object detection & instance segmentation on COCO
MambaOut이 Mamba를 넘지 못했다.
롱 시퀀스 특징이 있는 detection에서는 SSM이 필요하다.
3. Semantic segmentation on ADE20K
MambaOut이 Mamba를 넘지 못했다.
롱 시퀀스 특징이 있는 segmentation에서는 SSM이 필요하다.
코드 및 구현
중 Gated CNN Block 발췌
class GatedCNNBlock(nn.Module):
r""" Our implementation of Gated CNN Block: https://arxiv.org/pdf/1612.08083
Args:
conv_ratio: control the number of channels to conduct depthwise convolution.
Conduct convolution on partial channels can improve paraitcal efficiency.
The idea of partial channels is from ShuffleNet V2 (https://arxiv.org/abs/1807.11164) and
also used by InceptionNeXt (https://arxiv.org/abs/2303.16900) and FasterNet (https://arxiv.org/abs/2303.03667)
"""
def __init__(self, dim, expansion_ratio=8/3, kernel_size=7, conv_ratio=1.0,
norm_layer=partial(nn.LayerNorm,eps=1e-6),
act_layer=nn.GELU,
drop_path=0.,
**kwargs):
super().__init__()
self.norm = norm_layer(dim)
hidden = int(expansion_ratio * dim)
self.fc1 = nn.Linear(dim, hidden * 2)
self.act = act_layer()
conv_channels = int(conv_ratio * dim)
self.split_indices = (hidden, hidden - conv_channels, conv_channels)
self.conv = nn.Conv2d(conv_channels, conv_channels, kernel_size=kernel_size, padding=kernel_size//2, groups=conv_channels)
self.fc2 = nn.Linear(hidden, dim)
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
shortcut = x # [B, H, W, C]
x = self.norm(x)
g, i, c = torch.split(self.fc1(x), self.split_indices, dim=-1)
c = c.permute(0, 3, 1, 2) # [B, H, W, C] -> [B, C, H, W]
c = self.conv(c)
c = c.permute(0, 2, 3, 1) # [B, C, H, W] -> [B, H, W, C]
x = self.fc2(self.act(g) * torch.cat((i, c), dim=-1))
x = self.drop_path(x)
return x + shortcut
'논문리뷰' 카테고리의 다른 글
[논문 리뷰] CAE(Context Autoencoder) 요약, 코드, 구현 (0) | 2024.06.25 |
---|---|
[논문 리뷰] LLaVA-UHD 요약, 코드, 구현 (1) | 2024.06.19 |
[논문 리뷰] MoCo v3 요약, 코드, 구현 (0) | 2024.06.04 |
[논문 리뷰] MoCo v2 요약, 코드, 구현 (0) | 2024.05.27 |
[논문 리뷰] MoCo v1 요약, 코드, 구현 (0) | 2024.05.22 |