논문을 상세히 번역하고 한단어씩 해석해주는 포스팅은 많다.
나는 논문을 누구나 알아듣도록 쉽고 간결하게 전달하고자 한다.
UNet
저자의 의도
데이터 수가 적어도 높은 정확도를 보이는 FCN을 만들고자 했다.
기존 문제점
저자들은 생물학에서 CNN을 활용할 때 문제점 때문에 이 네트워크를 구상했다.
생물학에서는 segmentation에 CNN에 활용한다.
위치정보를 잘 보존해야 하는데 일반적인 CNN은 그렇지 않다.
그래서 나온게 FCN.
FCN은 인코더 디코더 구조를 사용해서 비교적 공간정보를 잘 보존한다.
FCN은 2가지 큰 문제점이 있다.
1. 느리다.
이미지 인식 + 공간정보 인식을 위해 분할작업을 하는데,
너무 많이 분할하기 때문에 모델이 매우 느리다.
2. 패치 사이즈에 따라 인식 사이즈도 바뀐다.
패치 사이즈가 작으면 작은 객체만 인식하고,
패치 사이즈가 크면 큰 객체만 인식한다.
해결 아이디어
1. U-shaped architecture
FCN에 skip connection을 추가한 U 형태 아키텍쳐를 구상했다.
FCN 구조는 인코더-디코더 구조를 말한다.
인코더는 입력 이미지의 low level features를 추출한다.
디코더는 입력 이미지의 high level features를 복원한 후 segmentation map을 생성한다.
핵심 아이디어는 skip connection이 추가된 것이다.
skip connection은 low level features를 high level features와 결합한다.
더 정교한 segmentation map을 생성하는데 도움을 주는 것이다.
skip connection으로 인해 context 손실, 위치정보 성능 2가지가 개선된다.
저자는 정보가 잘 전달된다고 설명했지만,
통계학적으로 볼 때는 기울기 소실이 덜 된 값을 전달하여 학습이 더 잘 된 것으로 보인다.
논문 구현
import torch
import torch.nn as nn
class DoubleConv(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
self.bn2 = nn.BatchNorm2d(out_channels)
self.relu2 = nn.ReLU(inplace=True)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu2(x)
return x
일단 컨볼루션 레이어가 계속 반복되기 때문에 하나의 클래스로 빌딩 블럭을 만든다.
class UNet(nn.Module):
def __init__(self, in_channels, out_channels):
super(UNet, self).__init__()
self.encoder_conv1 = DoubleConv(in_channels, 64)
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.encoder_conv2 = DoubleConv(64, 128)
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.encoder_conv3 = DoubleConv(128, 256)
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
self.encoder_conv4 = DoubleConv(256, 512)
self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
self.encoder_conv5 = DoubleConv(512, 1024)
self.upconv6 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
self.decoder_conv6 = DoubleConv(1024, 512)
self.upconv7 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
self.decoder_conv7 = DoubleConv(512, 256)
self.upconv8 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
self.decoder_conv8 = DoubleConv(256, 128)
self.upconv9 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
self.decoder_conv9 = DoubleConv(128, 64)
self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)
def forward(self, x):
x1 = self.encoder_conv1(x)
x2 = self.encoder_conv2(self.pool1(x1))
x3 = self.encoder_conv3(self.pool2(x2))
x4 = self.encoder_conv4(self.pool3(x3))
x5 = self.encoder_conv5(self.pool4(x4))
x = self.upconv6(x5)
x = torch.cat([x, x4], dim=1)
x = self.decoder_conv6(x)
x = self.upconv7(x)
x = torch.cat([x, x3], dim=1)
x = self.decoder_conv7(x)
x = self.upconv8(x)
x = torch.cat([x, x2], dim=1)
x = self.decoder_conv8(x)
x = self.upconv9(x)
x = torch.cat([x, x1], dim=1)
x = self.decoder_conv9(x)
x = self.final_conv(x)
return x
그리고 그림처럼 인코더, 디코더, 최종 아웃풋 3가지로 나눠서 만들면 된다.
인코더에서는 맥스풀링, 디코더에서는 업 컨볼루션(전치행렬) 하면 된다.
그리고 forward에서 숏컷이 들어올 때마다 cat으로 합쳐준다.
끝.
'논문리뷰' 카테고리의 다른 글
[논문 구현] ViT ImageNet 평가 방법 (0) | 2023.03.28 |
---|---|
[논문 리뷰] Vision Transformer(ViT) 요약, 코드, 구현 (0) | 2023.03.22 |
[논문 리뷰] Inception v1 요약, 코드, 구현 (0) | 2023.03.16 |
[논문 리뷰] VGGNet 요약, 코드, 구현 (0) | 2023.03.15 |
[논문 리뷰] Transformer (Attention Is All You Need) 요약, 코드, 구현 (0) | 2023.03.14 |