$$
\definecolor{input}{RGB}{66, 133, 244}
\definecolor{output}{RGB}{219, 68, 55}
\definecolor{dinput}{RGB}{244, 180, 0}
\definecolor{doutput}{RGB}{15, 157, 88}
\definecolor{dweight}{RGB}{102, 0, 255}
$$
逆伝播アルゴリズム
バックプロパゲーション アルゴリズムは、大規模なニューラル ネットワークを迅速にトレーニングするために不可欠です。この記事では、アルゴリズムの仕組みについて説明します。
下にスクロールしてください...
シンプルなニューラル ネットワーク
右側には、入力が 1 つ、出力ノードが 1 つ、2 つのノードの非表示レイヤ 2 つを備えたニューラル ネットワークが表示されています。
隣接するレイヤのノードは、ネットワーク パラメータであるウェイト \(w_{ij}\)で接続されます。
活性化関数
各ノードには、合計入力 \(\color{input}x\)、アクティベーション関数 \(f(\color{input}x\color{black})\)、出力 \(\color{output}y\color{black}=f(\color{input}x\color{black})\)があります。 \(f(\color{input}x\color{black})\) は非線形関数でなければなりません。そうでなければ、ニューラル ネットワークは線形モデルのみを学習できます。
よく使用されるアクティベーション関数は Sigmoid 関数です。
\(f(\color{input}x\color{black}) = \frac{1}{1+e^{-\color{input}x}}\)
誤差関数
目標は、 \(\color{output}y_{output}\)
すべての入力 \(\color{input}x_{input}\)で予測出力が目標に近づくように、データからネットワークの重みを自動的に学習することです。 \(\color{output}y_{output}\)
目標からどのくらい離れているかを測定するには、エラー関数を使用します \(E\)。よく使用されるエラー関数は \(E(\color{output}y_{output}\color{black},\color{output}y_{target}\color{black}) = \frac{1}{2}(\color{output}y_{output}\color{black} - \color{output}y_{target}\color{black})^2 \)です。
転送の伝播
まず、入力例を受け取り、 \((\color{input}x_{input}\color{black},\color{output}y_{target}\color{black})\) ネットワークの入力レイヤを更新します。
一貫性を保つため、入力は他のノードと同様のものですが、活性化関数がないものと仮定して、出力が入力と等しくなるようにします(例: \( \color{output}y_1 \color{black} = \color{input} x_{input} \))。
転送の伝播
ここでは、最初の隠しレイヤを更新します。前のレイヤのノードの出力を取得し、 \(\color{output}y\) 重みを使用して次のレイヤのノードの入力を計算します。 \(\color{input}x\) $$ \color{input} x_j \color{black} = $$$$ \sum_{i\in in(j)} w_{ij}\color{output} y_i\color{black} +b_j$$
転送の伝播
次に、最初の隠れ層のノードの出力を更新します。これには、アクティベーション関数 \( f(x) \)を使用します。$$ \color{output} y \color{black} = f(\color{input} x \color{black})$$
転送の伝播
この 2 つの数式を使用して、ネットワークの残りの部分に伝播し、ネットワークの最終出力を取得します。$$ \color{output} y \color{black} = f(\color{input} x \color{black})$$
$$ \color{input} x_j \color{black} = $$$$ \sum_{i\in in(j)} w_{ij}\color{output} y_i \color{black} + b_j$$
エラーのデリバティブ
逆伝播アルゴリズムは、予測出力を特定の例の目的の出力と比較した後、ネットワークの各重みを更新する量を決定します。このために、各重みに関してエラーがどのように変化するかを計算する必要があります \(\color{dweight}\frac{dE}{dw_{ij}}\)。
エラーの導関数を取得したら、シンプルな更新ルールを使用して重みを更新できます。$$w_{ij} = w_{ij} - \alpha \color{dweight}\frac{dE}{dw_{ij}}$$
ここで、 \(\alpha\) は正の定数です。これは学習率と呼ばれ、経験的に微調整する必要があります。
[注] 更新ルールはきわめてシンプルです。重みが増すとエラーが発生する(\(\color{dweight}\frac{dE}{dw_{ij}}\color{black} < 0\))場合は重みを上げ、重みを上げるとエラーが発生する(\(\color{dweight}\frac{dE}{dw_{ij}} \color{black} > 0\))場合は重みを減らします。
その他のデリバティブ
\(\color{dweight}\frac{dE}{dw_{ij}}\)を計算しやすくするため、ノードごとにさらに 2 つのデリバティブ、つまりエラーによる変化を次のように保存します。
- ノードの合計入力 \(\color{dinput}\frac{dE}{dx}\)
- ノードの出力 \(\color{doutput}\frac{dE}{dy}\)。
逆伝播
エラーのデリバティブの逆伝播を開始します。この特定の入力例の予測出力があるため、その出力によってエラーがどのように変化するかを計算できます。エラー関数から、 \(E = \frac{1}{2}(\color{output}y_{output}\color{black} - \color{output}y_{target}\color{black})^2\) 次のようになります。$$ \color{doutput} \frac{\partial E}{\partial y_{output}} \color{black} = \color{output} y_{output} \color{black} - \color{output} y_{target}$$
逆伝播
これで、チェーンルールを \(\color{doutput} \frac{dE}{dy}\) 使用して \(\color{dinput}\frac{dE}{dx}\) 取得できます。$$\color{dinput} \frac{\partial E}{\partial x} \color{black} = \frac{dy}{dx}\color{doutput}\frac{\partial E}{\partial y} \color{black} = \frac{d}{dx}f(\color{input}x\color{black})\color{doutput}\frac{\partial E}{\partial y}$$
\(\frac{d}{dx}f(\color{input}x\color{black}) = f(\color{input}x\color{black})(1 - f(\color{input}x\color{black}))\) の場合、 \(f(\color{input}x\color{black})\) が Sigmoid アクティベーション関数です。
逆伝播
ノードの合計入力に関する誤差導関数を取得したらすぐに、そのノードに入る重みに関するエラー導関数を取得できます。$$\color{dweight} \frac{\partial E}{\partial w_{ij}} \color{black} = \frac{\partial x_j}{\partial w_{ij}} \color{dinput}\frac{\partial E}{\partial x_j} \color{black} = \color{output}y_i \color{dinput} \frac{\partial E}{\partial x_j}$$
逆伝播
チェーンルールを使うと、前のレイヤからも \(\frac{dE}{dy}\) 取得できます。丸で囲むように作成しました。$$ \color{doutput} \frac{\partial E}{\partial y_i} \color{black} = \sum_{j\in out(i)} \frac{\partial x_j}{\partial y_i} \color{dinput} \frac{\partial E}{\partial x_j} \color{black} = \sum_{j\in out(i)} w_{ij} \color{dinput} \frac{\partial E}{\partial x_j}$$
逆伝播
あとは、すべてのエラーの導関数を計算するまで、前の 3 つの数式を繰り返すだけです。