본문 바로가기
논문리뷰

[논문 구현] ViT ImageNet 학습하는 방법

by davidlds 2023. 6. 8.
반응형

ImageNet 데이터셋을 받아야 한다.

(이게 진짜 킹받는다 ^^)

 

ImageNet 오피셜 홈페이지

https://image-net.org/download-images

 

ImageNet

Download ImageNet Data ImageNet does not own the copyright of the images. For researchers and educators who wish to use the images for non-commercial research and/or educational purposes, we can provide access through our site under certain conditions and

image-net.org

선수입장

 

회원가입

회원가입 페이지
회원가입 페이지

이메일은 학교메일(마지막에 ac.kr로 끝나는거) 권장한다.

별표 없는 것도 다 채워야한다.

 

신청버튼이 나오면 클릭하고, 메일가서 인증한다.

이제 다운가능.

이미지 다운로드
이미지 다운로드

여기서 제일 위에 Training Images 받으면 된다.

 

다운로드한 파일(tar 확장자)을 ./data/ImageNet/ 폴더에 넣는다.

./data/ImageNet/ 폴더에서 터미널을 켠다.

 

train 이미지 푸는 코드를 실행한다.

mkdir train && mv ILSVRC2012_img_train.tar train/ && cd train
tar -xvf ILSVRC2012_img_train.tar && rm -f ILSVRC2012_img_train.tar
find . -name "*.tar" | while read NAME ; do mkdir -p "${NAME%.tar}"; tar -xvf "${NAME}" -C "${NAME%.tar}"; rm -f "${NAME}"; done
cd ..

 

train 폴더가 생겼다.

이제 이걸로 신나게 학습시키면 된다.

 

친절하게 데이터로더까지 코드를 쳐줄거다.

transform_train = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
transform_test = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
train_set = torchvision.datasets.ImageFolder('./data/train', transform=transform_train)
train_loader = data.DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
test_set = torchvision.datasets.ImageFolder('./data/val', transform=transform_test)
test_loader = data.DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)

 

 

학습 코드는 대충 이렇게 짜면 된다.

class FineTunner(object):
    def __init__(self):
        self.model = None
        self.optimizer = None
        self.epochs = []
        self.losses = []

    def process(self):
        self.build_model()
        self.finetune_model()
        self.save_model()

    def build_model(self):
        self.model = ViTPooling(image_size=IMAGE_SIZE,
                                patch_size=PATCH_SIZE,
                                in_channels=IN_CHANNELS,
                                num_classes=NUM_CLASSES,
                                embed_dim=EMBED_DIM,
                                depth=DEPTH,
                                num_heads=NUM_HEADS,
                                ).to(device)
        checkpoint = torch.load(pre_model_path)
        self.epochs = checkpoint['epochs']
        self.model.load_state_dict(checkpoint['model'])
        self.losses = checkpoint['losses']
        print(f'Parameter: {sum(p.numel() for p in self.model.parameters() if p.requires_grad)}')
        print(f'Classes: {NUM_CLASSES}')
        print(f'Epoch: {self.epochs[-1]}')

    def finetune_model(self):
        model = self.model
        criterion = nn.CrossEntropyLoss()
        optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

        for epoch in range(NUM_EPOCHS):
            running_loss = 0.0
            saving_loss = 0.0
            for i, data in tqdm(enumerate(train_loader, 0), total=len(train_loader)):
                inputs, labels = data
                inputs, labels = inputs.to(device), labels.to(device)

                optimizer.zero_grad()

                outputs = model(inputs)
                loss = criterion(outputs, labels)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
                saving_loss = loss.item()
                if i % 100 == 99:
                    print(f'[Epoch {epoch}, Batch {i + 1:5d}] loss: {running_loss / 100:.3f}')
                    running_loss = 0.0
            self.epochs.append(epoch + 1)
            self.model = model
            self.optimizer = optimizer
            self.losses.append(saving_loss)
            self.save_model()
        print('****** Finished Pre-training ******')
        self.model = model

    def save_model(self):
        checkpoint = {
            'epochs': self.epochs,
            'model': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'losses': self.losses,
        }
        torch.save(checkpoint, fine_model_path)
        print(f"****** Model checkpoint saved at epochs {self.epochs[-1]} ******")

혹시 validation이나 pre-training이 궁금하다면 아래 포스팅을 보자.

(validation) https://davidlds.tistory.com/14

(pre-training) https://davidlds.tistory.com/24

 

끝.

 

반응형