본문 바로가기
AI/PyTorchZeroToAll by Sung Kim

[PyTorchZeroToAll] 9. Softmax Classifier

by 쵸빙 2020. 1. 27.

     이번 시간에는 신경망에서 폭넓게 쓰이는 softmax classifier에 대해서 알아보도록 하겠다.

 

MNIST dataset

     손으로 쓴 숫자들이 0~9 사이에서 무엇인지를 알아보는 MNIST 데이터셋이다.

 

     이것은 classification 문제이고, 하나의 입력을 받아서 0 또는 1의 하나의 출력을 내는 이미 존재하는 logistic regression model을 사용할 수 있을 것이다. 

 

     그러나 우리는 0~9까지의 10개가 가능하기 때문에 출력이 10개인 것이 더 합당할 것이다. 이것을 어떻게 구현하고 출력을 예측해야할 것인가? 우리는 matrix 곱셈을 이용해서 이것을 계산할 수 있다.

 

     저번 시간에 왼쪽 행렬과 같이 입력이 2개일 때, 가중치 행렬이 2 x 1의 형식이어서, 출력 결과를 올바르게 낼 수 있었다. 그렇다면, 이번 MNIST 데이터셋과 같이 10개의 가능한 출력이 있을 때 우리의 가중치 행렬은 어떤 모양이어야할까? 만약에 입력 데이터의 개수가 m개였다면, 출력은 m x 10의 행렬 형태일 것이다. 

 

     그렇다면 출력이 10개가 있을 때 어떻게 그것들을 확률로 나타내야할까? Softmax 함수를 이용하면 가능하다.

 

Softmax function

     softmax 함수의 식은 위와 같다. x가 주어지고 linear 함수를 통과하면 z가 결정될 것이고, 이것을 scores(logits)라고 한다. 여기서 softmax를 적용하면 각 index에 대하여 probability를 구할 수 있다. 

 

     입력이 주어지면 linear function에 대한 결과를 얻은 다음 softmax 함수를 거친 결과를 1 -hot label로 표현한다. 1 hot label은 하나만 1이고 나머지는 0인 것을 말한다.

 

Cross entropy

     손실 함수로는 cross entropy를 사용한다. cross entropy의 식은 위와 같고, 우리의 예측인 y hat과 실제 값인 y와의 차이를 나타낸다. 그리고 전체 loss는 이 cross entropy들의 합이다.

 

     cross entropy를 실제로 사용한 결과는 위와 같다. 우리는 one hot encoding을 사용하므로 실제 값인 Y가 1, 0, 0의 형태이다. 첫번째 예측인 Y_pred1은 실제 값과 비슷하게 첫번째 원소일 확률이 가장 높게 나왔으므로 cross entropy 결과가 0.35로 작게 나왔고, 두번째 예측인 Y_pred2는 세번째 원소일 확률이 가장 높게 나와서 정답과 거리가 멀으므로 2.30으로 cross entropy 결과가 크게 나왔다.

 

     역시나 PyTorch에서는 cross entropy 함수도 제공하기 때문에 간편하게 가져다 쓰면 된다.

loss 함수를 사용할 때에는 뒤에 오는 실제 값인 Y가 one hot이 아니여야한다는 점에 주의하자.

우리의 예에서는 0 또는 1 또는 2일 것이다. 또 softmax 함수가 CrossEntropyLoss 함수 안에 이미 포함되어있기 때문에 softmax 함수를 거치기 전인 logit 결과를 Y_pred 자리에 넣으면 된다.

기존에 정의된 함수인 CrossEntropyLoss를 사용해도 첫번째 예측의 결과는 0.41로 작고, 두번째 예측의 결과는 1.84로 보다 커서 올바른 결과가 나온 것을 알 수 있다.

 

     CrossEntropyLoss를 사용했을 때 좋은 점은 위와 같이 batch로 multiple prediction을 할 수 있다는 것이다.

첫번째 예측에서는 첫번째 행의 index 2가 0.9로 가장 크고, 두번째 행의 index 0이 1.1로 가장 크고, 세번째 행의 index 1이 2.1로 가장 크므로 올바른 예측을 했다고 볼 수 있다. 반면에, 두번째 예측에서는 첫번째 행의 index 0이 0.8로 가장 크고, 두번째 행의 index 2가 0.5로 가장 크고, 세번째 행의 index 2가 0.4로 가장 크므로 모조리 다 틀렸다.

그렇기 때문에 cross entropy는 첫번째 예측이 0.5로 작게, 두번째 예측이 1.24로 크게 나와서 올바른 결과가 나왔다.

 

     NLLLoss를 softmax classifier에서 어떻게 사용하는지 생각해보자.

 

     MNIST 입력에서는 softmax classifier가 어떻게 사용되는지 알아보자. 위와 같이 28 x 28 픽셀을 가지는 숫자 손글씨 사진에서는 총 784 pixel을 가진다.

 

     MNIST 네트워크에서는 입력 층이 784개이고, 출력 층이 10개가 된다.

 

     중간의 은닉층까지 함께 보자. 은닉층에서는 몇 개의 층을 가질지, 그 층이 몇 개의 입력과 출력을 가질지는 모델을 디자인하는 사람 마음에 달렸다.

 

     위와 같은 에에서는 코드가 위와 같이 나올 것이다. 첫번째 층은 784를 입력으로 받고 520을 출력으로 내고, 다음 층은 520을 입력으로 받고 320을 출력으로 내는 형식이다.

 

 

     forward에서는 [1,28,28] 형태로 받은 데이터를 [-1, 784]로 flatten한다. 그 이후에 각각의 층으로 넘긴다.

마지막 층에서는 활성 함수를 쓰지 않는데, logit을 입력으로 받아서 cross entropy 함수를 쓸 것이기 때문이다.

 

     위와 같이 cross entropy loss 함수를 쓴 것을 train 단계에서 예측에 사용할 것이다.

 

 

    전체 소스코드는 위와 같다. https://github.com/hunkim/PyTorchZeroToAll/blob/master/09_2_softmax_mnist.py

 

hunkim/PyTorchZeroToAll

Simple PyTorch Tutorials Zero to ALL! Contribute to hunkim/PyTorchZeroToAll development by creating an account on GitHub.

github.com

원본 코드는 위의 링크를 따라가기를 바란다.

 

     실행 결과는 위와 같다. 정확도가 97%로 매우 높은 것을 알 수 있다.

label이 여러 개인 예측에서도 cross entropy loss를 사용하면 쉽게 결과를 얻을 수 있다.

 

 

다음 시간에는 CNN에 대해 배우도록 하겠다.