본문 바로가기
Error 잡기

Transformer confidence score 보는 방법

by davidlds 2023. 5. 9.
반응형

Transformer를 평가할 때 종종 confidence score가 사용된다.

confidence score는 해당 클래스에 얼마나 확신이 있는지다.

즉 모델이 얼마나 자신있게 '이건 닭이야!!!!!!!!!!1' 라고 하는거다.

이 confidence score를 뽑아내는 방법을 알아보자.

 

이미지넷을 사용할껀데

Cifar-10같은 다른걸로 해도 무방하다.

나처럼 이미지넷으로 할거면 일단 다운부터 받자.

 

 

[논문 구현] ViT ImageNet 평가 방법 | pytorch, timm 라이브러리, timm ViT

ViT는 트랜스포머 중에서 그나마 간단한 형태이다. 실제로 구현하는게 그리 어렵지는 않다. 하지만.......... 논문에서 '대용량 pre-training'이 안된 ViT는 퍼포먼스가 상당히 떨어진다고 나온다. 다시

davidlds.tistory.com

 

코드 복붙

from matplotlib import pyplot as plt
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader
from torchvision import datasets
from tqdm import tqdm

from CV.util import imagenet_ind2str

device = 'mps'
# device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
BATCH_SIZE = 1
NUM_WORKERS = 2

transform_test = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

test_set = datasets.ImageFolder('./data/ImageNet/val', transform=transform_test)
test_loader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

def show_conf(n):
    model = timm.models.vit_base_patch16_224(pretrained=True)
    print(f'Parameter: {sum(p.numel() for p in model.parameters() if p.requires_grad)}')
    print(f'Classes: {model.num_classes}')
    print(f'****** Model Creating Completed. ******')
    model.to(device).eval()
    with torch.no_grad():
        for idx, (images, labels) in enumerate(test_loader):
            if idx == n:
                images = images.to(device)
                labels = labels.to(device)
                outputs = model(images)

                _, pred = torch.max(outputs, 1)
                probs = torch.nn.functional.softmax(outputs, dim=1)[0]
                conf = probs[pred].to('cpu')

                print(f'Label : {imagenet_ind2str(int(labels))}')
                print(f'Predict : {imagenet_ind2str(int(pred))}')
                print(f'Confidence : {float(conf):.3f}')
                break

디바이스는 맥북이면 mps, 엔비디아면 cuda 쓰면 된다.

n번째 이미지의 confidence를 보기 위해 이렇게 했다.

 

추가로 설명을 하자면 probs에서 outputs 텐서를 확률값으로 고친다.

그리고 가장 큰 값인 pred 번째 값만 출력한다.

 

아래 이미지에 대한 결과.

이미지
이미지
컨피던스
컨피던스

 

끝.

반응형