MLエンジニアへの道 #27 - レイヤー正規化

Last Edited: 10/16/2024

このブログ記事では、ディープラーニングにおけるレイヤー正規化について紹介します。

ML

前回の記事でLSTMの構造が勾配消失問題を緩和できることを説明しましたが、依然として勾配爆発問題に悩まされる可能性があることも指摘しました。 MLエンジニアへの道 #15 - 勾配消失・勾配爆発 では、勾配爆発問題に対して勾配クリッピングが有効だと説明しましたが、勾配の不安定性に対して最も強力な技術としてバッチ正規化(batch normalization)が取り上げられていました。

しかし、バッチ内のデータポイントごとにシーケンスの長さが異なる場合、バッチ正規化は機能しません。 バッチ正規化はバッチ内の各アクティベーションの平均と分散を使ってアクティベーションを正規化するためです。 また、バッチサイズが小さいときにもパフォーマンスが低下します。したがって、バッチ正規化は入力や出力の長さが可変であるRNNやLSTMには不向きです。

レイヤー正規化

この問題を解決するために、レイヤー正規化を導入します。レイヤー正規化では、 各データポイントごとの全アクティベーションの平均と分散を使って正規化を行います。 これにより、出力の長さが可変であっても、バッチサイズに依存しない正規化が可能になります。 以下にレイヤー正規化の数式を示します。

h={x1,x2,...,xn}μh=1Ni=1Nxiσh2=1Ni=1N(xiμh)2h^=hμhσh2+ϵhout=γh^+β h = \{x_1, x_2, ..., x_n\} \\ \mu_h = \frac{1}{N}\sum_{i=1}^{N}x_i \\ \sigma_h^2 = \frac{1}{N}\sum_{i=1}^{N}(x_i - \mu_h)^2 \\ \hat{h} = \frac{h - \mu_h}{\sqrt{\sigma_h^2 + \epsilon}} \\ h_{out} = \gamma \hat{h} + \beta

この正規化の手順はバッチ正規化とほぼ同じですが、バッチサイズに依存せず、各データポイントに対して行われる点が異なります。 これにより、各データポイントのシーケンス長が異なっても計算と並列化が可能になります。また、バッチ正規化では、 バッチを使用しない推論時にランニング平均と分散を追跡する必要がありますが、レイヤー正規化では、 学習時・テスト時・推論時で異なる処理を必要としません。以下がレイヤー正規化のTensorFlowの実装例です。

class LayerNormalization(layers.Layer):
  def __init__(self, embed_dim, epsilon=1e-5):
    super(LayerNormalization, self).__init__()
    self.gamma = tf.Variable(tf.ones((1, embed_dim)), trainable=True, name="gamma")
    self.beta = tf.Variable(tf.zeros((1, embed_dim)), trainable=True, name="beta")
    self.epsilon = epsilon
 
  def call(self, x):
    mean = tf.expand_dims(tf.reduce_mean(x, axis=-1), axis=-1)
    var = tf.expand_dims(tf.math.reduce_variance(x, axis=-1), axis=-1)
    normalized = (x - mean) / tf.math.sqrt(var + self.epsilon)
    out = self.gamma * normalized + self.beta
    return out

練習としてPyTorchでも実装してみることをお勧めします。 Ba, L. J. らによる2016年の論文では、特にRNNにおいて長いシーケンスやミニバッチでレイヤー正規化が有効であることが実証されました。 一方で、CNNにおいてはバッチ正規化がレイヤー正規化よりも優れていることが確認されています。しかし高い並列化性能や学習の高速化、 可変長シーケンスへの対応能力のおかげで、レイヤー正規化は、NLPのようなシーケンス処理を伴うタスクにおいて今でも広く使われています。 レイヤー正規化の層は、TensorFlowとPyTorchの両方で、バッチ正規化と同様に事前定義されています。興味があれば、 前回と前々回の記事で作成したRNNやLSTMモデルにレイヤー正規化層を追加してみることをお勧めします。

解決されていない大きな問題

前回の記事を含め、私たちはLSTM、GRU、勾配クリッピング、レイヤー正規化など、 RNNの不安定な勾配問題を解決するためのさまざまな手法に触れました。 また、双方向(bidirectional)やディープRNNといったモデル構造が、 モデルのパフォーマンス向上に寄与することにも触れました。

これらの手法は、モデルのパフォーマンスや学習に対して徐々に改善をもたらしましたが、 RNNは依然として長期記憶の保持や勾配爆発問題に苦しんでいました。モデルのパフォーマンスは頭打ちになり、 自然なテキストを生成できるレベルには達しませんでした。

しかし、最も重要なのは、前回の記事で触れた最大の問題、つまりRNNの並列化の欠如が、 これらの手法では解決されていないということです。どのセルやモデル構造を選んだとしても、 計算と時間逆伝播(BPTT: Back propagation through time)の逐次的な性質は依然として残り、これが並列化の利点を享受することを妨げています。 このため、モデルの学習や推論が非常に遅くなります。最近では、この問題に対するいくつかの解決策が提案されていますが、 依然としてその解決策には多くの不確実性があります。

したがって、次回の記事からは、最近広く使われている並列化可能な代替アプローチについて、いよいよ議論を進めていきます。

リソース