Study/Data Science

합성 함수 계산그래프와 연쇄법칙(chain rule) (구 블로그 글 복구)

Railly Linker 2025. 4. 12. 00:23

- 이번 포스팅으론 딥러닝 모델 학습을 이해하기 위해 반드시 이해해야 하는 오차 역전파법을 이해하기 위한 선행 지식을 정리하겠습니다.

 

[합성함수란?]
함수의 실행 단계를 2개 이상으로 나눌수 있는 함수입니다.

예를들면,

z = (x+y)^2

이라는 함수는,
먼저 x+y라는 계산을 먼저 하고,
이후에 제곱을 해야합니다.

병렬적으로 동시에 처리할수 있는 것이 아니고, 순서를 지니죠.

- 즉, 합성함수는, 한 함수의 결과값이 다른 함수의 입력값이 되어 사용되는 함수를 말합니다.
위의 함수에서,

t = x + y

라고 하고,

z = t^2

으로 나누어 표현할수 있습니다.

즉, t 값이 먼저 있고, z값을 찾는 것입니다.

[계산그래프란?]
- 계산 그래프를 쉽게 이해하기 위해 예시를 하나 들어보겠습니다.

이런 계산을 한번 해봅시다.

 

"""
단가가 100원인 사과가 있을 때, 이를 2개 샀습니다.
이 사과의 소비세가 10퍼센트라고 했을 때, 최종 지출 가격은?

"""

위와 같은 간단한 계산이 있다고 하면, 이는 최종적으로 합성함수를 구하는 것입니다.

예제가 단순하므로, 계산을 나눠볼 것도 없지만, 일단 사과 단가와 사과 갯수를 먼저 곱하여 물건 값을 계산해내고, 그 값에 소비세 1.1을 곱하는 것입니다.

그렇게 해야 최종결과가 나오는 것이죠.

이를 계산그래프로 나타내면,

예제 1 계산 그래프

 

위와 같이 됩니다.

조금 앞서나간 설명이 몇개 있는데,
형태만 일단 알아둡시다.

- 계산그래프는, 계산이 실행되는 부분인 노드와,
계산에 사용될 데이터 값들이 지나가는 부분인 에지로 나뉩니다.

위를 보면, 입력값 2개를 먼저 곱하고, 그 결과값인 중간 결과를, 또다른 곱셈 노드에 넣어주며, 이번에는 1.1이라는 소비세를 추가해 곱해줬죠.


그리하여 최종값이 나오게 되는 하나의 사과값 계산 함수가 만들어지는 것입니다.
(프로그래밍에 익숙하신 분이라면, 입력값이 매개변수고, 최종결과가 return값, 그리고 중간 노드부분이, 우리가 함수 내에서 다른 함수를 사용하는 것이라 생각하시면 됩니다.)

- 계산그래프의 순전파와 역전파
계산그래프는 위에 보이는데로 순서가 있기에 방향성을 지닙니다.

입력데이터에서, 앞의 계산 노드를 지나서 계산값을 내는 것 까지를 순전파라 부릅니다.

그런데 역전파는 무엇일까요?


미분을 아신다고 가정하고 설명합니다.
최종 결과가 최종 결과에 미치는 영향은 얼마나 될까요?
전부입니다.

사실, 최종 결과는 말 그대로 결과로, 원인이 될수 없지만, 일단 최종 결과는 자기 자신에 대해 1이라는 기울기를 가집니다.
즉, 1*자기자신 = 자기자신 인 것이죠.

최종 결과에 대한 영향력은, 뒤에서 앞으로 전파됩니다.
앞서 계산노드가 무엇인가에 따라, 해당 값이 결과에 미치는 영향이 달라지죠.

일단 계산 노드에 따른 역전파의 성질은 조금 이따 알아보고,
이 역시 순전파와 마찬가지로, 이전의 영향력이 앞의 영향력에 영향을 미치며 커진다는 것입니다.
(이 말은, 합성함수가 복잡하고 커질수록, 가장 앞부분이 되는 입력값에 가까운 수치들이 최종 결과에 미치는 영향이 커진다는 것을 뜻합니다.)

- 계산그래프를 사용하는 이유 : 전체 함수에 대해 보다 파악하기 쉽고 시각적으로 이해하기 쉬운데, 무엇보다 '국소적 계산'이 가장 큰 이유입니다.

오차 역전파라는 딥러닝 학습에 사용되는 기법을 이해하기 위한 근간이라는 것은, 해당 글에서 밝히도록하고, 그냥, 나눌수 있는 연산 분절의 중간 계산결과를 알수있는 것이라 생각하세요.

- 역전파에 대해 잠깐 확인해보면,
최종결과에 대한 각 부분의 영향력이라 했죠?

한번 실제로 계산해보세요.
예를들면, 사과 값이 110원으로 올라가면 최종 결과가 어떻게 될까요?
이는 우리가 미리 '역전파'로 구한 기울기(영향력)로 쉽게 계산이 가능합니다.
그냥 110에 2.2를 곱해보세요. 242가 나옵니다.

그러면 실제 순전파를 실행하면 해당 값이 나올까요?
110*2*1.1로 순전파를 실행해도 242라는 값이 나오는 것을 알수 있습니다.
자, 최종 결과에서 노드들로 역전파를 실행해서 기울기를 나누어주면 아주 쉽게 각 부분이 최종 결과에 미치는 영향에 대해 알아낼수 있겠죠?

- 위와 같은 이유로, 우리는 최종 결과에 대하여 미분을 수행하는 것보다, 이렇게 역전파를 하는 것이 무척 쉽고 빠른 방법이라는 것을 알수 있습니다.(미분은, 아시다시피, 순전파 계산을, 영향을 미치는 x값의 수에 맞게 반복해줘야합니다. 거기서 나온, 미세하게 변한 y값과 x값의 변화값을 나누어줘서 기울기를 구하는 것이죠.)

- 이러한 법칙이 체인룰입니다.

자, 이제 체인룰, 즉 연쇄법칙에 대해 알아보죠.

[미분법칙]
- 간단하게, 자주 사용되는 미분 법칙을 일단 기록합니다. 참고하세요.


f(x) = 상수  ->  f`(x) = 0
f(x) = e^x  ->  f`(x) = e^x
f(x) = e^-x  ->  f`(x) = -e^-x
f(x) = ax^n  ->  f`(x) = nax^n-1
f(x) = lnx  ->  f`(x) = 1/x

[체인룰이란?]
- 위에서 살펴본, 계산그래프 역전파시, 국소적 미분을 전달하는 원리가 바로 연쇄법칙에 따른 것입니다.


- 연쇄법칙은 간단합니다.
'합성함수의 미분은 합성함수를 구성하는 각 함수의 미분의 곱으로 나타낼수 있다.'
입니다.

- 가장 간단한 함수 f(x)의 계산그래프를 한번 보도록 하죠.

함수 f 계산 그래프

위를 보면, x가 함수 노드 f에 들어가서 y라는 값을 내보냅니다.
합성함수도 아닌 간단한 함수인데, 이를 계산그래프에따라 합성함수라고 생각해봅시다.

f(x) = y 함수를 진행 하고,
이후 y = y*1이라는 함수를 진행했다고 생각해보세요.

이를 역전파 하면, 최종 에지인 y가 최종결과 y에 미치는 영향은 1이겠죠?
미분까지 할것 없이도, y가 1이 늘어나면 최종결과가 1이 늘어나고, 2가 늘어나면 2가 늘어나죠.
미분도 한번 해볼까요?

E(ry/ry)

위에서 기호 라운드를 그냥 r로 표현했습니다.
식으로 치자면, y의 변화에 따라 y에 미치는 영향으로,
저기서의 함수가 y = y이기에,


위에서 적은, 미분법칙에 따라서

 

f(x) = ax^n  ->  f`(x) = nax^n-1

를 적용시켜보면,

y의 1승을 앞으로 두게 되어, y의 0승이 되므로, 곧 미분값이 1이 되는 것입니다.

그리고 노드 f의 이전 에지인 x로 역전파해보면,


E(ry/rx) * E(ry/ry)

로, 체인룰에 따라서, 이전 에지의 미분값이 전파되어 곱해집니다.

E(ry /ry) = 1이니까, 결국은, 위의 함수의 미분은, E(ry/rx)가 되는 것이죠.

(이전 노드의 계산 결과는, 다음 노드의 계산에도 영향을 주니, 곧 노드를 통과하기 전의 에지는, 통과한 후의 에지의 미분값에도 영향을 미치는 것이니 당연한 것이겠죠?)

- 체인룰은, 국소적 계산에 대한 미분들을 구하여 구합니다.
즉, 최종 결과에 대한 미분이 아니라, 계산그래프로 치자면, 바로 뒤 노드의 결과값에 대한 미분값을 구해서 곱합니다.

국소 미분값 역전파

 

위의 계산그래프를 보시면, 국소미분값이 역전파되는 과정을 볼수 있습니다.

1. z에 대한 z의 영향력,
2. z에 대한 t의 영향력 곱하기 z에 대한 z의 영향력
3. t에 대한 x의 영향력 곱하기 z에 대한 t의 영향력 곱하기 z에 대한 z의 영향력

이런 순서대로 역전파가 진행됩니다.
일반적인 계산 법칙에 따라서 계산해보세요.


2번에서는, 분자의 rz와 분모의 rz가 상쇄되면서, rz/rt로, z에 대한 t의 영향력이 되지요?
1번에서도 이처럼 상쇄를 시켜주면 최종적으로 rz/tx가 되어, 최종값 z에 대해 x의 영향력으로, 이로써 체인룰과 역전파가 제대로 동작함을 우리는 증명할수 있습니다.

[노드별 역전파 법칙]
- 위와 같은 원리를 이해한다면, 우리는 각 부분에 실제 미분을 사용할수 있는데,
사실 미분을 직접 전개할 필요가 없습니다.
계산 노드별로 미분의 역전파 법칙이 있기에, 그것을 적용하면 됩니다.

- 대표적인것 몇개만 알아보겠습니다.


(덧셈 노드의 역전파)
- 덧셈을 미분하면 어떻게 될까요?

편미분을 배우셨다면 알수 있을 것입니다.

x + y = z

라는 함수에서,
입력값은 x와 y죠?

rz/rx 값을 구해보죠.

인자값이 2개이므로 편미분을 실행하면(미분 관련해서는 수학 카테고리에 제가 정리한 글이 있습니다.), y값을 특정 값으로 고정해야겠죠?

y가 고정된 상태에서 x값의 변화에 대해서 z에 어떤 영향을 가질까요?

x + 1 = z

라는 함수를 가정했을 때,
x값이 1이 증가하면, z도 1이 증가하고,
x가 2가 증가하면, z도 2가 증가하겠죠?

즉, rz/rx는 1이라는 기울기를 가집니다.
rz/ry 역시 1입니다.

- 위와 같은 결과로 인하여, 덧셈 노드에 대한 역전파를 실행시, 이전 미분값에 1이라는 덧셈 노드의 국소미분값이 곱해지므로, 각 에지에는 이전 미분값이 골고루 전파되는 것입니다.

덧셈 계산 그래프

 

위의 그림을 보면, 이전 에지의 국소미분값이, 각 x, y 에지에 전파됩니다.

원래는 rz/rx * rz/rz인데,

rz/rx가, 위의 설명대로 1이라는 것을 우리는 알고있으므로, 덧셈노드는 이전 미분에 1을 곱하는 것입니다.

(곱셈노드의 역전파)
- 곱셈 함수 z = xy에 대하여 미분을 진행해봅시다.

- 곱셈의 경우는 물건 단가 계산을 생각해보면 됩니다.
x가 개수, y가 단가라고 하면, x가 1개 늘어날때마다 단가 y를 곱하기에, z는 y를 곱한만큼 올라가겠죠?

즉, rz / rx는 y이고,
rz / ry는 x입니다.

직접 미분을 실행하지 않아도, 기울기 공식에 따라서, z = ax이므로, a는 곧 y가 됨을 알수 있습니다.

 

곱셈 계산 그래프

 

위를 보시면, 곱셈 노드가 계산 그래프를 타고 역전파되는 것을 볼수 있습니다.

rz/rx가 y이고, rz/ry는 x라는 것을 이제 알고있으니, 이를 이용하면, 직접 미분을 수행할 것 없이 쉽게 연산이 가능하겠죠?

[마치며]
- 대략 오차역전파를 이해하기 전의 선행지식을 정리해봤습니다.
어떤 함수던, 위와 같이 간단한 노드들의 조합인 것이고, 이러한 함수의 미분값이 체인룰에 따라서 아주 쉽게 알아낼수 있을 것이며,
복잡할 것도 없이 중간 에지들의 기울기값을 알아낼수 있습니다.

- 이러한 오차역전파법을 통해 기울기를 구하는 것은, 최종결과에 따른 미분을 모든 요소에 적용하는 것보다, 후방 에지의 미분값을, 전방 에지에 전달하여, 간단히 곱하는 것으로 기울기를 구할수 있기에, 무척이나 경제적입니다.

(최종결과에 대한 미분은, 최전방 에지에 있더라도, delta값을 더한 값을 최후방까지 순전파 계산을 실행해야합니다. 하지만 이러한 체인룰을 통한 후방 에지의 미분값을 재활용할수 있다면, 후방에서 전방까지 역전파가 진행될수록, 계산해야하는 순전파의 길이는 줄어들게 되는 것입니다.)