이번 시간에는 매우 중요한 개념인 역전파(Back-propagation)에 대해서 다루도록 하겠다.
저번 시간에는 x를 입력으로 받고, y hat을 출력으로 하는 매우 간단한 선형 모델에 대해서 배웠다.
우리는 손실 함수를 가중치로 미분하여 gradient를 계산하는 training 과정을 거쳤다.
저번 시간에는 간단한 네트워크였기 때문에 올바른 가중치를 구하기 위해 일일이 gradient를 계산하는 것이 가능했지만 만약 매우 복잡한 네트워크를 만나게 된다면 이 과정이 너무 오래 걸릴 수 있다.
chain rule을 이용하여 계산 그래프를 이용하면 더 쉽게 계산할 수 있다.
먼저 chain rule에 대해서 알아보겠다. 위에서 설명하듯이 g_(g underscore)함수는 x를 입력으로 받아서 g를 출력으로 내보낸다는 뜻이고, 마찬가지로 f_는 g를 입력으로 받아서 f를 출력으로 내보내는 함수라는 뜻이다.
우리는 최종적으로 f를 x로 미분한 값을 얻고 싶은데, 먼저 f를 g로 미분한 값을 구한 뒤 거기에 g를 x로 미분한 값을 곱하면 그 결과를 얻을 수 있다. 이것은 우리가 g와 x, g와 f의 관계를 미리 알고 있기 때문에 간단하게 구할 수 있다.
그래서 각각의 gradient를 구한 뒤 곱해서 최종 gradient를 알 수 있기 때문에 사슬로 엮여 있는 모습같다고 해서 chain rule이라고 불린다.
BackPropagation에서 chain rule이 어떻게 적용되는지 알아보자. 큰 네트워크의 일부인 f라는 노드가 있다고 하자. 이 f라는 함수는 x와 y라는 입력을 받아서 z라는 출력을 내보내고, 그 출력이 다른 많은 노드들을 거쳐서 결국에는 loss에 다다를 것이다. 맨 끝에서부터 거꾸로 올라오면서 loss 함수를 출력으로 미분한 값을 받아서 현재의 노드의 local gradient를 곱해서 그 결과를 앞으로 넘긴다. loss function을 x, y 각각의 입력값으로 미분한 결과를 알아내기 위해 먼저 loss function을 z로 미분한 결과에 z를 y로 미분한 결과를 곱해서 알아낸다. 그렇다면 forward와 backward 계산에 대해 따로 자세히 알아보자.
먼저 forward propagation의 경우는 간단하게 알 수 있다. 위의 예시에서는 f 함수가 곱셈 연산으로, f = x * y의 형태이기 때문에 x로 f를 미분한 결과는 x * y를 x로 미분한 것이므로 y이고, y로 f를 미분한 결과는 x * y를 y로 미분한 것이므로 x이다.
그다음으로는 backward propagation에 대해 배워보도록 하겠다. 어떻게 손실 함수를 Z로 미분한 결과가 5라고 알게 되었다고 치자. L를 x로 미분한 결과는 L을 z로 미분한 결과에 z를 x로 미분한 결과를 곱한 것이므로 이 경우에는 5 * y, 즉 5 * 3 = 15가 될 것이다. 마찬가지로, L을 y로 미분한 결과는 L을 z로 미분한 결과에 z를 y로 미분한 결과를 곱한 것이므로 이 경우에는 5 * 2 = 10이 될 것이다.
미분값을 computational graph로 그리면 위와 같다.
forward propagation은 위와 같이 이루어질 것이다.
backward propagation은 위와 같이 이루어진다. loss를 s로 미분, y hat으로 미분, w로 미분하는 과정을 지나면 결과가 -2로 나오게 된다.
다른 예시로 위의 문제를 풀어보면 결과가 -4가 된다.
이제는 bias도 더해진 경우를 생각해보자.
pytorch의 좋은 점은 back propagation을 일일이 할 필요 없이 Variable이라는 것을 사용하면 된다는 것이다
위와 같이 l.backward() 함수를 사용하면 back propagation을 수행한다.
loss 함수를 w로 미분한 결과는 w.grad.data에 저장된다. 그 것 중 하나만 사용하고 싶다면 w.grad.data[0] 이런 식으로 사용하면 된다.
직접 하나하나한 것과 파이토치 라이브러리를 사용한 방법은 위와 같이 결과가 거의 같다.
연습문제로 위 문제를 pytorch로 구현해보자.
위의 문제를 직접, 그리고 pytorch 라이브러리를 이용해서 두 가지 방법 모두로 결과를 구해보자.
다음 시간에는 Linear regression을 파이토치로 하는 법을 알아보도록 하겠다.
'AI > PyTorchZeroToAll by Sung Kim' 카테고리의 다른 글
[PyTorchZeroToAll] 6. Logistic Regression (0) | 2020.01.23 |
---|---|
[PyTorchZeroToAll] 5. Linear Regression in the PyTorch way (0) | 2020.01.22 |
[PyTorchZeroToAll] 3. Gradient Descent (0) | 2020.01.21 |
[PytorchZeroToAll] 2. Linear Model (0) | 2020.01.21 |
[PyTorchZeroToAll] 1. Overview (0) | 2020.01.20 |