MLエンジニアへの道 #28 - Transformers

Last Edited: 10/26/2024

このブログ記事では、ディープラーニングにおけるTransformerについて紹介します。

ML

これまでの数回の記事では、RNNモデルがコンテキストを記憶する際に抱える課題について、研究者たちが開発したさまざまな解決策や、 並列化が困難であるという重要な欠点を議論してきました。今回の記事では、コンテキストを用いて洗練された単語の埋め込みを作成 する並列化可能な新しいアルゴリズムで今日の深層学習で主流となっているモデル—Transformer—をついに取り上げます!

位置エンコーディング

コンテキストを捉えつつ並列化可能なモデルを作成するためには、シーケンス内の入力の位置を表現することが重要です。 RNNはその逐次性が自然に位置をエンコードしていましたが、通常の密結合層では介入なしで位置をエンコードすることはできません。 Transformerモデルで使用されている位置エンコーディングでは、埋め込みの次元数に応じた正弦波と余弦波を利用し、 各位置でのこれらの波の値を使って位置をエンコードします。

P(k,2i)=sin(kn2im)P(k,2i+1)=cos(kn2im) P(k, 2i) = \text{sin}(\frac{k}{n^{\frac{2i}{m}}}) \\ P(k, 2i+1) = \text{cos}(\frac{k}{n^{\frac{2i}{m}}})

上の数式は位置エンコーディングを適用するための数学的操作を示しています。ここで、mmは埋め込みの次元、kkは入力の位置、 iiは0からm2\frac{m}{2}の範囲で使用する波の番号、P(k,i)P(k, i)はエンコーディング関数を指します。nnはハイパーパラメータ(元の論文ではデフォルトで10,000)で、 波の周期に影響を与えます。これらの波の結果を各kkですべて合計して位置エンコーディングを作成し、埋め込みに追加します。 正弦波と余弦波の値は-1から1の範囲に収まるため、位置エンコーディングが正規化された範囲内に保たれます。

def positional_encoding(seq_len, embed_dim, n=10000):
    P = np.zeros((seq_len, embed_dim))
    for k in range(seq_len):
        for i in np.arange(int(embed_dim/2)):
            denominator = np.power(n, 2*i/embed_dim)
            P[k, 2*i] = np.sin(k/denominator)
            P[k, 2*i+1] = np.cos(k/denominator)
    return P

上のコードは、Saeed, M.(2023年)による位置エンコーディングのPython実装です。この実装では、上記の数式を直接使用して位置エンコーディング行列を作成し、 結果は単語埋め込みに追加できます。位置エンコーディングの方法はこれだけではありませんが、ここでの方法はシンプルで、NLPにおいて広く使用されています。

アテンションメカニズム

位置のエンコードを行った後、いよいよコンテキストに基づいて微妙な単語の埋め込みを作成する段階に入ります。 まずは計算の詳細を一旦置いておいて、概念的にメカニズムを理解しましょう。ここで、アテンション(注意)という概念を導入します。 これは、ある単語が他の単語にどの程度影響を与えるべきかを示す指標です。アテンションを得るために、 クエリキーという概念も導入します。これは、それぞれの単語が関連性の高い単語がどのようなものかを予測(クエリ)し、 他の単語と関連しているか(キー)を表します。特定の単語セットのクエリとキーが似ている場合、その単語がアテンション/注意を払う対象であることを意味します。

アテンションを計算した後、今度はバリューを計算します。これは、他の単語によって実際にどのように単語が修正されるべきかを決定し、 アテンションとバリューで単語の埋め込みを更新します。少し抽象的に感じるかもしれませんので、例を挙げてみましょう。 たとえば、"I confirmed the validity of the document, so I put a seal on it."という文で"seal"を分析すると、 "document"や"official"、"animal"、"marine"といった単語に注目し、単語の表現を適切に調整する必要があります。

この場合、"seal"は"document"や"official"といった単語が関連性の高い単語であると予測し(クエリ)、それらの単語も"seal"との関連性が高いことを期待しています(キー)。 クエリとキーが一致するため、これらの関連性の高い単語に他の無関係な単語よりも多くの注意(アテンション)を払います。 そして、文脈に応じて"seal"は、スタンプや海洋動物のどちらの意味を表すかに応じて単語の表現を調整します(バリュー)。 このように、アテンションをバリューに適用して単語の表現を更新します。注:この説明は概念理解のためであり、実際にクエリ、キー、バリューがどのように学習されるかを必ずしも正確に反映していないかもしれません。

アテンションの計算

計算的には、まずそれぞれのクエリとキー用にサイズ(m,n)(m, n)mmは埋め込みの次元、nnは潜在次元)の共有行列WQW_QおよびWKW_Kを設定し、サイズ(d,m)(d, m)ddは文中の単語数)の単語埋め込みEEをこれらの行列に掛けて、 単語を次元の低いクエリとキーに変換します。これにより、(d,n)(d, n)サイズのQQKKが得られます。クエリとキー間の類似度を計算するためには、 ドット積の類似度(QKTQK^T)を使用します。サイズ(d,d)(d, d)の結果行列がアテンションに対応します。

次に、サイズ(m,m)(m, m)の共有バリュー行列WVW_Vを設定し、単語を変換してサイズ(d,m)(d, m)のバリューVVを生成します。 最終的に、アテンションとバリューを掛けてサイズ(d,m)(d, m)の行列を得て、残差接続により単語の表現をE~\tilde{E}に更新します。 アテンションをバリューに掛ける前に、アテンションをd\sqrt{d}で正規化し、ソフトマックス関数を適用します。

Q,K,V=EWQ,EWK,EWVAttention(Q,K,V)=softmax(QKTd)VE~=E+Attention(Q,K,V) Q, K, V = E W_Q, E W_K, E W_V \\ \text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d}})V \\ \tilde{E} = E + \text{Attention}(Q, K, V)

しかし、実際には、サイズ(m,m)(m, m)のバリュー行列WVW_Vを直接扱うのは効率的ではありません。特に大きな埋め込み次元の場合、これを避けるため、 MLエンジニアへの道 #24 - GloVeで説明した行列分解を用いて、 サイズ(m,n)(m, n)の2つの行列Wv1W_{v1}Wv2W_{v2}を作成します。これにより、次のような計算式になります。

Q,K,V=EWQ,EWK,EWv1Wv2TAttention(Q,K,V)=softmax(QKTd)VE~=E+Attention(Q,K,V) Q, K, V = E W_Q, E W_K, E W_{v1} W_{v2}^T \\ \text{Attention}(Q, K, V) = \text{softmax}(\frac{QK^T}{\sqrt{d}})V \\ \tilde{E} = E + \text{Attention}(Q, K, V)

この処理は一見複雑に見えるかもしれませんが、実際には一連の行列積と正規化の手順に過ぎず、すべて微分可能で完全に並列化可能です。 このアプローチの欠点としては、固定された単語数ddしか処理できない点が挙げられますが、これはEOS(文の終わり)タグの使用やパディング、 モデルの拡大などで解決可能です。シーケンスの継続予測が必要な場合は、入力をシフトさせ、以前の出力を含めることで対応できます。

class SelfAttention(layers.Layer):
  def __init__(self, embedding_dim, hidden_dim):
    super(SelfAttention, self).__init__()
    self.wq = layers.Dense(hidden_dim)
    self.wk = layers.Dense(hidden_dim)
    self.wv1 = layers.Dense(hidden_dim)
    self.wv2 = layers.Dense(embedding_dim)
 
  def call(self, x):
    Q = self.wq(x)
    K = self.wk(x)
    V = self.wv2(self.wv1(x))
 
    Aw = tf.matmul(Q, K, transpose_b=True)
    d = tf.cast(tf.shape(K)[-1], tf.float32)
    Aw /= tf.math.sqrt(d)
    Aw = tf.nn.softmax(Aw, axis=-1)
    A = tf.matmul(Aw, V)
    return A

上記のコードは、アテンション(自己アテンション)のTensorFlow実装です。クエリ、キー、バリュー行列は密層によって生成され、 アテンションを計算するために操作されます。コードを見れば、アテンションの仕組みが非常にシンプルであることがわかります。 チャレンジとして、このコードをPyTorchで実装してみることをお勧めします。

アテンションの種類

クエリ、キー、バリュー行列が単語間で共有されるため、単一のアテンションモジュール(シングルヘッドアテンション)だけでは埋め込みの更新が十分でない可能性があります。 そこで、複数のアテンションモジュール、いわゆるマルチヘッドアテンションを並行して設定し、シーケンスの異なる側面で更新を行うことが一般的です。

アテンションの概念を理解するために使用した例は自己アテンションの例であり、同じシーケンス内でアテンションが計算されます。一方、別の種類としてクロスアテンションがあり、 異なるシーケンスからキーとバリューを取得してアテンションを適用します。自己アテンションは、文の意味を抽出したり次の単語を予測したりするのに有用であり、 クロスアテンションは抽出された意味を使って異なる言語への翻訳や異なるモダリティ(例:音声の文字起こし)でのシーケンス生成に役立ちます。

トランスフォーマーは、トークナイザー、埋め込み層、位置エンコード、アテンションとフィードフォワード層で構成されるトランスフォーマーレイヤーを利用する並列化可能なモデルアーキテクチャです。 各コンポーネントの実装はタスクの性質によって異なる場合がありますが、多くの場合、トランスフォーマーはマルチヘッド自己アテンションモジュールを使用しています。

なぜTransformerなのか?

ここまで読んで、「なぜ、従来のモデルにポジショナルエンコーディングを追加するだけでなく、固定入力長に対応した新しいモデルを使用する必要があるのか?」 と疑問に思う方もいるかもしれません。従来のFNNやCNNでも同様に固定入力長に対応できるはずで、何かしらの工夫ができたのではないでしょうか?結論から言うと、 実験的に従来のモデルは他の手法、特にRNNよりも精度が劣り、Transformerがより優れていることが示されているからというのが主な理由です。 次に生まれる疑問は、「なぜ従来のモデルではうまくいかず、Transformerではうまくいったのか?」という点です。

その答えは、モデルの帰納バイアス にあるかもしれません。帰納バイアスとは、モデルが未知のデータに対して合理的な予測を学習できるようにするために注入された仮定の集合のことです。 CNNやRNNなどのモデルには、位置の絶対的な意味が問われない「並進不変性」や、局所的な特徴を組み合わせてパターンを捉える「局所受容野」、特徴の階層性、シーケンシャルな依存性、 時間的順序付けなどの仮定が組み込まれています。これらの我々が帰納的に獲得し注入したバイアスは、画像やテキストのような特定のデータタイプに適しているため、 少ない重みで複雑なパターンを迅速に学習できるようになっています。

こうした帰納バイアスは、データセットが小規模で複雑さが中程度の場合には役立ちますが、これらのバイアスに適合しない複雑な関係を学習する際には障害となることがあります。 研究者たちがデータと計算資源へのアクセスを増やし、自然言語のような大きく複雑なデータをより正確にモデリングしようとする中で、 これらのバイアスは逆効果となる可能性がありました。また、FNN単体では、Transformerのように入力によって異なるアクティベーションに「注意」を向けることができません。 十分な数のニューロンを使えば同様の計算を近似できますが、合理的なモデルサイズとデータセットの範囲内で複雑な関係を効率的に学習することはできません。

そのため、Transformerは、現代に多く見られる大規模で複雑なデータに対応するのに適した、適度な帰納バイアスと柔軟性を兼ね備えていると言えます。 一般的な用途で学習済みモデルを特定の用途に活用する転移学習やファインチューニングのような技術や、Transformerを取り巻くツールの急速な発展、 最近の成功例から、近年研究者やエンジニアの注目を集めている理由も理解できます。

結論

本記事では、ポジショナルエンコーディング、アテンションの概念および数学的説明や種類、Transformerがさまざまなタスク(特にNLP)で優れているとされる理由を紹介しました。 TransformerはRNNとは異なり、高い並列性を持ちながらも、効果的に文脈情報を単語埋め込みに反映でき、固定入力長のデメリットを上回る利点を持っています。(少なくとも現時点では。) 次回の記事では、具体的なTransformerアーキテクチャをさらに詳しく解説していきます。

リソース