記事一覧

ニューラルネットワークの誤差逆伝搬の導出

この記事は MCC Advent Calendar 2017 - Adventar の12日目の記事です。
またもや2連続ですね.なんだかSyntaxHighlighterのインデントがおかしいなぁ・ω・;

今回の内容


ニューラルネットワークの学習アルゴリズムのひとつである,誤差逆伝搬を導出します.今回も数学っぽいです.

ニューラルネットワークとは


ニューラルネットワークは,簡単に言うと行列演算の塊です.その行列要素が多数のパラメータになっていて,色々な関数を禁じできる,というものです.
本質的にはフーリエ級数と変わりません.フーリエ級数は周期関数を近似するものですが,
$$
\begin{eqnarray*}
f(t) = a_0 &+& a_1 \cos(\omega t) + a_2 \cos(\omega t) + \cdots \\
&+& b_1 \sin(\omega t) + b_2 \sin(\omega t) + \cdots \\
\end{eqnarray*}
$$
のように多数のパラメータ$a_0, a_1, a2, \cdots, b_1, b_2, \cdots$をうまく調整することで関数近似が実現されます.
同じように,ニューラルネットワークでもパラメータを使って,より複雑な関数を近似するのです.

3層ニューラルネットワーク


今回は,3層で考えます.具体的なネットワーク構成は次のとおりです.

NN_3Layers.png

これは,入力としてベクタ$x$を与え,層1へ入力します.この層で,行列$W_1$を乗算し,バイアスベクタ$b_1$を加算します.具体的に,
$$
\begin{eqnarray*}
u_1 = W_1 \cdot x + b_1
\end{eqnarray*}
$$
というアフィン変換のような計算を行います.最後にこの$u_1$を活性化関数$f_1$に通して,この層の出力$x_2$を算出します.ここで,活性化関数はベクタの成分ごとに適用するもので,
$$
\begin{eqnarray*}
f_k(x) = f_k \Mat{x_1 \\ x_2 \\ \vdots} = \Mat{f_k(x_1) \\ f_k(x_2) \\ \vdots},
\qquad (k=1,2,3)
\end{eqnarray*}
$$
となります.
そして,同じ演算を層2,層3と続けることで,最終的にネットワークの出力$y$を得ます.しかし,この値は所望の出力値(教師データ)$t$とは一致しませんので,誤差$y-t$が発生します.これを,評価関数$J$で評価してやります.この評価関数$J$が最小になるよう,各層の行列$W_k$,バイアスベクタ$b_k$の値をうまく調整してやります(この過程を「学習」と呼ぶんだそうです.また,評価関数は$J=(y-t)^2/2$のような形などが用いられます).

誤差逆伝搬


ネットワークの出力$y$を所望の出力$t$に近づけるためには,パラメータ$W_k$,$b_k$を調整しなくてはなりません.このため,勾配法が使われます.これは,評価関数の微分を減算し,値を更新するというものです.
$$
\begin{eqnarray*}
W_k &-=& r\PDif{J}{W_k}, \\
b_k &-=& r\PDif{J}{b_k}
\end{eqnarray*}
$$
ここで,$r$は学習率という係数で,値の更新の速度を調整するためのものです.大きすぎても解が収束せず,小さすぎてもなかなか更新が進まないので,適切な値を設定しなければなりません.
このアルゴリズムを実行するために,
$$
\begin{eqnarray*}
\PDif{J}{W_1},~\PDif{J}{W_2},~\PDif{J}{W_3},~\PDif{J}{b_1},~\PDif{J}{b_2},~\PDif{J}{b_3}
\end{eqnarray*}
$$
の値を計算しなければなりませんが,
$$
\begin{eqnarray*}
J = J(y, t)
\end{eqnarray*}
$$
であり,さらにネットワーク出力は
$$
\begin{eqnarray*}
y &=& f_3(u_3)
= f_3(W_3 \cdot x_3 + b_3) \\
&=& f_3(W_3 \cdot f_2(u_2) + b_3)
= f_3(W_3 \cdot f_2(W_2 \cdot x_2 + b_2) + b_3)\\
&=& f_3(W_3 \cdot f_2(W_2 \cdot f_1(u_1) + b_2) + b_3)
= f_3(W_3 \cdot f_2(W_2 \cdot f_1(W_1 \cdot x + b_1) + b_2) + b_3)\\
\end{eqnarray*}
$$
ですから,面倒な微分計算になりそうで一筋縄では行きそうにありません.一応,3層目のパラメータ$\P J / \P W_3,~\P J / \P b_3$は計算できそうですのでやってみましょう.まずは$\P J / \P b_3$から.
$$
\begin{eqnarray*}
\PDif{J}{b_3} &=& \PDif{y}{b_3} \cdot \PDif{J}{y}
= \PDif{u_3}{b_3} \cdot \PDif{y}{u_3} \cdot \PDif{J}{y}
= I \cdot \PDif{y}{u_3} \cdot \PDif{J}{y}
= \PDif{y}{u_3} \cdot \PDif{J}{y}
=: \delta_3
\end{eqnarray*}
$$
微分のチェーンルール(約分の逆のように見えるもの)を使っていることに注意してください.また,最後に
$$
\begin{eqnarray*}
\delta_3 := \PDif{y}{u_3}(u_3) \cdot \PDif{J}{y}(y)
\end{eqnarray*}
$$
と定義しました($\P y / \P u_3$の引数は$u_3$,$\P J / \P y$の引数は$y$です).なぜわざわざ記号を新設するのか...?まあ後々便利なのでしばらく我慢してください.
次に$\P J / \P W_3$を計算してみましょう.
$$
\begin{eqnarray*}
\PDif{J}{W_3} &=& \PDif{y}{W_3} \cdot \PDif{J}{y}
= \PDif{u_3}{W_3} \cdot \PDif{y}{u_3} \cdot \PDif{J}{y}\\
&=& \PDif{u_3}{W_3} \cdot \delta_3
= I \otimes x_3^\top \cdot \delta_3
= x_3^\top \otimes \delta_3
\end{eqnarray*}
$$
ここで,前回の$\delta_3 = \P y / \P u_3 \cdot \P J / \P y$を使い,式を簡略化しました.が,$\delta_3$を導入した意図はこれをやるためではありません.次の層の計算で威力を発揮するのです.

2層目の$\P J / \P W_2,~\P J / \P b_2$はどうでしょう.$J$を微分するとしても,$y = f_3( W_3 \cdot f_2(W_2 \cdot x_2 + b_2) + b_3 )$を$b_2, W_2$で微分する必要があり,少々複雑そうです.しかし,微分を書き下してみると戦略が見えてきます.まず,$\P J / \P b_2$を計算します.
$$
\begin{eqnarray*}
\PDif{J}{b_2} &=& \PDif{y}{b_2} \cdot \PDif{J}{y}
= \PDif{u_3}{b_2} \cdot \PDif{y}{u_3} \cdot \PDif{J}{y}
= \PDif{x_3}{b_2} \cdot \PDif{u_3}{x_3} \cdot \PDif{y}{u_3} \cdot \PDif{J}{y} \\
&=& \PDif{u_2}{b_2} \cdot \PDif{x_3}{u_2} \cdot \PDif{u_3}{x_3} \cdot \PDif{y}{u_3} \cdot \PDif{J}{y}
\end{eqnarray*}
$$
ここで,末尾の$\P y / \P u_3 \cdot \P J / \P y$は$\delta_3$に書き換えられることがわかります.
$$
\begin{eqnarray*}
&=& \PDif{u_2}{b_2} \cdot \PDif{x_3}{u_2} \cdot \PDif{u_3}{x_3} \cdot \delta_3
\end{eqnarray*}
$$
さらに,他の微分は
$$
\begin{eqnarray*}
\PDif{u_2}{b_2} &=& \PDif{W_2 x_2 + b_2}{b_2} = I, \\
\PDif{x_3}{u_2} &=& \PDif{f(u_2)}{u_2} = \DIAG{f'(u_2)}, \\
\PDif{u_3}{x_3} &=& \PDif{W_3 x_3 + b_3}{x_3} = W_3^\top.
\end{eqnarray*}
$$
となるので(プライム「'」は微分)
$$
\begin{eqnarray*}
\PDif{J}{b_2} &=& I \cdot \DIAG{f'(u_2)} \cdot W_3^\top \cdot \delta_3 \\
&=& \DIAG{f'(u_2)} \cdot (W_3^\top \cdot \delta_3) \\
\end{eqnarray*}
$$
と,計算できます.これを見ると,3層目での計算結果$\delta_3$を利用して2層目の勾配$\P J / \P b_2$を計算していることがわかります.このような微分のチェーンルールを利用することで,同じ計算を繰り返さず,効率よく計算できることがわかります.
ここで,diagとベクタの積はもう少し簡略化できます.diagの行列は,
$$
\begin{eqnarray*}
\DIAG{f'(u_2)} = \Mat{
f'(u_2^1) & \\
& f'(u_2^2) & \\
& & \ddots
}
\end{eqnarray*}
$$
ですが,これと適当なベクタ$v$との積は
$$
\begin{eqnarray*}
\DIAG{f'(u_2)} = \Mat{
f'(u_2^1) & \\
& f'(u_2^2) & \\
& & \ddots
}
\cdot \Mat{ v_1 \\ v_2 \\ \vdots }
= \Mat{ f'(u_2^1) v_1 \\ f'(u_2^2)v_2 \\ \vdots }
= f'(u_2) \odot v
\end{eqnarray*}
$$
と,成分同士の積に変換できます(アダマール積).よって,$b_2$による勾配は
$$
\begin{eqnarray*}
\PDif{J}{b_2}
= f'(u_2) \odot (W_3^\top \cdot \delta_3)
=: \delta_2
\end{eqnarray*}
$$
と書けます.ここでも$\delta_2$という変数を新設しました.
・・・もうおわかりでしょう.このδが誤差であり,それが順次前の層へ逆伝搬していくために「誤差逆伝搬」よばれます.
つづいて$W_2$による勾配も求めます.
$$
\begin{eqnarray*}
\PDif{J}{W_2} &=& \PDif{y}{W_2} \cdot \PDif{J}{y}
= \PDif{u_3}{W_2} \cdot \PDif{y}{u_3} \cdot \PDif{J}{y}
= \PDif{u_3}{W_2} \cdot \delta_3 \\
&=& \PDif{u_2}{W_2} \cdot \PDif{x_3}{u_2} \cdot \PDif{u_3}{x_3} \cdot \delta_3 \\
&=& \PDif{W_2 \cdot x_2 + b_2}{W_2} \cdot \PDif{f(u_2)}{u_2} \cdot \PDif{W_3 \cdot x_3 + b_3}{x_3} \cdot \delta_3 \\
&=& \PDif{W_2 \cdot x_2 + b_2}{W_2} \cdot \DIAG{f'(u_2)} \cdot (W_3^\top \cdot \delta_3) \\
&=& \PDif{W_2 \cdot x_2 + b_2}{W_2} \cdot \delta_2 \\
&=& (I \otimes x_2)^\top \cdot \delta_2 \\
&=& x_2^\top \otimes \delta_2
\end{eqnarray*}
$$

この調子で3層目の勾配も求めてみましょう.$\P J / \P b_1$を計算します.
$$
\begin{eqnarray*}
\PDif{J}{b_1} &=& \PDif{u_2}{b_1} \cdot \PDif{J}{u_2}
= \PDif{u_2}{b_1} \cdot \delta_2 \\
&=& \PDif{u_1}{b_1} \cdot \PDif{x_2}{u_1} \cdot \PDif{u_2}{x_2} \cdot \delta_2 \\
&=& \PDif{W_1 \cdot x + b_1}{b_1} \cdot \PDif{f(u_1)}{u_1} \cdot \PDif{W_2 \cdot x_2 + b_2}{x_2} \cdot \delta_2 \\
&=& I \cdot \DIAG{f'(u_1)} \cdot W_2^\top \cdot \delta_2 \\
&=& f'(u_1) \odot (W_2^\top \cdot \delta_2)
=: \delta_1
\end{eqnarray*}
$$
$\P J / \P W_1$を計算します.
$$
\begin{eqnarray*}
\PDif{J}{W_1} &=& \PDif{u_1}{W_1} \cdot \PDif{J}{u_1}
= \PDif{u_1}{W_1} \cdot \delta_1 \\
&=& \PDif{W_1 \cdot x + b_1}{W_1} \cdot \delta_1 \\
&=& (I \otimes x)^\top \cdot \delta_1 \\
&=& x^\top \otimes \delta_1
\end{eqnarray*}
$$

まとめ


以上より,誤差逆伝搬に必要な勾配が以下のように求まりました.
$$
\begin{eqnarray*}
\PDif{J}{b_3} &=& \PDif{u_3}{b_3} \cdot \PDif{J}{u_3} = \delta_3,
\qquad \Pt{ \delta_3 = \PDif{J}{u_3} = \PDif{y}{u_3} \cdot \PDif{J}{y} = f'(u_3) \odot \PDif{J}{y} }\\
\PDif{J}{W_3} &=& \PDif{u_3}{W_3} \cdot \PDif{J}{u_3} = x_3^\top \otimes \delta_3. \\
\PDif{J}{b_2} &=& \PDif{u_2}{b_2} \cdot \PDif{J}{u_2} = \delta_2,
\qquad \Pt{ \delta_2 = f'(u_2) \odot (W_3^\top \cdot \delta_3) } \\
\PDif{J}{W_2} &=& \PDif{u_2}{W_2} \cdot \PDif{J}{u_2} = x_2^\top \otimes \delta_2. \\
\PDif{J}{b_1} &=& \PDif{u_1}{b_1} \cdot \PDif{J}{u_1} = \delta_1,
\qquad \Pt{ \delta_1 = f'(u_1) \odot (W_2^\top \cdot \delta_2) } \\
\PDif{J}{W_1} &=& \PDif{u_1}{W_1} \cdot \PDif{J}{u_1} = x_2^\top \otimes \delta_1. \\
\end{eqnarray*}
$$
あとは,これを使って勾配法を計算すればいいです.

Comments

Post a comment

Private comment