본문 바로가기
논문리뷰

[논문 리뷰] UNet 요약, 코드, 구현

by davidlds 2023. 3. 20.
반응형

논문을 상세히 번역하고 한단어씩 해석해주는 포스팅은 많다.

나는 논문을 누구나 알아듣도록 쉽고 간결하게 전달하고자 한다.

 

UNet

U-Net: Convolutional Networks for Biomedical Image Segmentation
RONNEBERGER, Olaf; FISCHER, Philipp; BROX, Thomas. U-net: Convolutional networks for biomedical image segmentation. In: Medical Image Computing and Computer-Assisted Intervention–MICCAI 2015: 18th International Conference, Munich, Germany, October 5-9, 2015, Proceedings, Part III 18. Springer International Publishing, 2015. p. 234-241.
 

 

저자의 의도

데이터 수가 적어도 높은 정확도를 보이는 FCN을 만들고자 했다.

 

기존 문제점

저자들은 생물학에서 CNN을 활용할 때 문제점 때문에 이 네트워크를 구상했다.

생물학에서는 segmentation에 CNN에 활용한다.

위치정보를 잘 보존해야 하는데 일반적인 CNN은 그렇지 않다.

그래서 나온게 FCN.

FCN은 인코더 디코더 구조를 사용해서 비교적 공간정보를 잘 보존한다.

FCN은 2가지 큰 문제점이 있다.

 

1. 느리다.

이미지 인식 + 공간정보 인식을 위해 분할작업을 하는데,

너무 많이 분할하기 때문에 모델이 매우 느리다.

 

2. 패치 사이즈에 따라 인식 사이즈도 바뀐다.

패치 사이즈가 작으면 작은 객체만 인식하고,

패치 사이즈가 크면 큰 객체만 인식한다.

 

해결 아이디어

1. U-shaped architecture

아키텍쳐
아키텍쳐

FCN에 skip connection을 추가한 U 형태 아키텍쳐를 구상했다.

 

FCN 구조는 인코더-디코더 구조를 말한다.

인코더는 입력 이미지의 low level features를 추출한다.

디코더는 입력 이미지의 high level features를 복원한 후 segmentation map을 생성한다.

 

핵심 아이디어는 skip connection이 추가된 것이다.

skip connection은 low level features를 high level features와 결합한다.

더 정교한 segmentation map을 생성하는데 도움을 주는 것이다.

skip connection으로 인해 context 손실, 위치정보 성능 2가지가 개선된다.

 

저자는 정보가 잘 전달된다고 설명했지만,

통계학적으로 볼 때는 기울기 소실이 덜 된 값을 전달하여 학습이 더 잘 된 것으로 보인다.

 

논문 구현

import torch
import torch.nn as nn


class DoubleConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        return x

일단 컨볼루션 레이어가 계속 반복되기 때문에 하나의 클래스로 빌딩 블럭을 만든다.

 

class UNet(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UNet, self).__init__()
        self.encoder_conv1 = DoubleConv(in_channels, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder_conv2 = DoubleConv(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder_conv3 = DoubleConv(128, 256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder_conv4 = DoubleConv(256, 512)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder_conv5 = DoubleConv(512, 1024)

        self.upconv6 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.decoder_conv6 = DoubleConv(1024, 512)
        self.upconv7 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.decoder_conv7 = DoubleConv(512, 256)
        self.upconv8 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.decoder_conv8 = DoubleConv(256, 128)
        self.upconv9 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.decoder_conv9 = DoubleConv(128, 64)

        self.final_conv = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        x1 = self.encoder_conv1(x)
        x2 = self.encoder_conv2(self.pool1(x1))
        x3 = self.encoder_conv3(self.pool2(x2))
        x4 = self.encoder_conv4(self.pool3(x3))
        x5 = self.encoder_conv5(self.pool4(x4))

        x = self.upconv6(x5)
        x = torch.cat([x, x4], dim=1)
        x = self.decoder_conv6(x)
        x = self.upconv7(x)
        x = torch.cat([x, x3], dim=1)
        x = self.decoder_conv7(x)
        x = self.upconv8(x)
        x = torch.cat([x, x2], dim=1)
        x = self.decoder_conv8(x)
        x = self.upconv9(x)
        x = torch.cat([x, x1], dim=1)
        x = self.decoder_conv9(x)

        x = self.final_conv(x)
        return x

그리고 그림처럼 인코더, 디코더, 최종 아웃풋 3가지로 나눠서 만들면 된다.

인코더에서는 맥스풀링, 디코더에서는 업 컨볼루션(전치행렬) 하면 된다.

그리고 forward에서 숏컷이 들어올 때마다 cat으로 합쳐준다.

끝.

 

관련 논문 리스트 (스크롤 내려서 Paper List 참고)

반응형