MLエンジニアへの道 #21 - 転置畳み込み

Last Edited: 9/21/2024

このブログ記事では、ディープラーニングにおける転置畳み込みについて紹介します。

ML

前回の記事では、CNN とその次元拡張に関する問題について取り上げました。今回は、その問題に対する解決策である 転置畳み込み(Transposed Convolution) について説明します。

転置畳み込み

カーネル畳み込みでは、ピクセル値とカーネル値の線形結合を計算します。しかし、この方法ではパディングが適用されない限り、次元を拡張することはできません。 代わりに、ピクセル値とカーネル値を要素ごとに乗算し、次のように結果を組み合わせることができます。

TConv

上記の例では、2x2 の画像(左)に2x2 のカーネル(右)で転置畳み込みを適用しています。これにより、 転置畳み込みによって画像の次元が2x2 から 3x3に拡張できることが確認できます。しかし、 前回の記事で次元を拡張するもう一つの隠れた操作がありました。それは、入力活性化の逆伝播の際に行われた、 回転されたカーネルを使った完全な畳み込みです。

FullConv

上記の図は、回転されたカーネルを使用した完全畳み込みを示しています。実際、この次元拡張のための順伝播にも使用される畳み込みは、 転置畳み込みと同じ操作*で、次元拡張に使用できます。(つまり、畳み込みは入力に対する勾配を求める際に転置畳み込みを使用しているということです。)

逆伝播

カーネルの重みを学習するためには、損失関数のカーネルの重みに対する偏微分と入力特徴に対する偏微分を計算し、 それを使ってさらに逆伝播を行う必要があります。まず、転置畳み込み(Transposed Convolution)の操作を数学的に表現しましょう。

O=XF O = X \circledast F

ここで、OO は転置畳み込みの出力、XX は入力特徴、\circledastは転置畳み込み操作のシンボル、そして FF はフィルターやカーネルを表しています。 上述のように転置畳み込みを適用すると、次のように計算が行われます。

O1,1=X1,1F1,1O1,2=X1,2F1,1+X1,1F1,2O1,3=X1,2F1,2O2,1=X1,1F2,1+X2,1F1,1O2,2=X1,1F2,2+X1,2F2,1+X2,1F1,2+X2,2F1,1O2,3=X1,2F2,2+X2,2F1,2O3,1=X2,1F2,1O3,2=X2,1F2,2+X2,2F2,1O3,3=X2,2F2,2 O_{1,1} = X_{1,1} F_{1,1} \\ O_{1,2} = X_{1,2} F_{1,1} + X_{1,1} F_{1,2} \\ O_{1,3} = X_{1,2} F_{1,2} \\ O_{2,1} = X_{1,1} F_{2,1} + X_{2,1} F_{1,1} \\ O_{2,2} = X_{1,1} F_{2,2} + X_{1,2} F_{2,1} + X_{2,1} F_{1,2} + X_{2,2} F_{1,1} \\ O_{2,3} = X_{1,2} F_{2,2} + X_{2,2} F_{1,2} \\ O_{3,1} = X_{2,1} F_{2,1} \\ O_{3,2} = X_{2,1} F_{2,2} + X_{2,2} F_{2,1} \\ O_{3,3} = X_{2,2} F_{2,2}

まず、カーネルの重みに対する損失勾配を計算しましょう。これは次のように表現できます。

LFi=k=1MLOkOkFi \frac{\partial L}{\partial F_i} = \sum_{k=1}^{M} \frac{\partial L}{\partial O_k} \frac{\partial O_k}{\partial F_i}

この式をF1,1F_{1,1}に対して展開すると、次のようになります。

LF1,1=LO1,1O1,1F1,1+LO1,2O1,2F1,1+LO2,1O2,1F1,1+LO2,2O2,2F1,1 \frac{\partial L}{\partial F_{1,1}} = \frac{\partial L}{\partial O_{1,1}} \frac{\partial O_{1,1}}{\partial F_{1,1}} + \frac{\partial L}{\partial O_{1,2}} \frac{\partial O_{1,2}}{\partial F_{1,1}} + \frac{\partial L}{\partial O_{2,1}} \frac{\partial O_{2,1}}{\partial F_{1,1}} + \frac{\partial L}{\partial O_{2,2}} \frac{\partial O_{2,2}}{\partial F_{1,1}}

XXFFを単に掛け算しているため、OF1,1\frac{\partial O}{\partial F_{1,1}}の偏微分は対応するXXに等しくなります。したがって、F1,1F_{1,1}に対する微分を以下のように書き換えることができます。

LF1,1=LO1,1X1,1+LO1,2X1,2+LO2,1X2,1+LO2,2X2,2 \frac{\partial L}{\partial F_{1,1}} = \frac{\partial L}{\partial O_{1,1}} X_{1,1} + \frac{\partial L}{\partial O_{1,2}} X_{1,2} + \frac{\partial L}{\partial O_{2,1}} X_{2,1} + \frac{\partial L}{\partial O_{2,2}} X_{2,2}

これはすべてのフィルタ値FFに適用されます。したがって、カーネルの重みに対する損失関数の偏微分は、XX と出力に対する損失勾配の畳み込みで表されます。

LF=XLO \frac{\partial L}{\partial F} = X * \frac{\partial L}{\partial O}

次に、入力特徴 XX に対する損失勾配を計算します。これは次のように表現できます。

LXi=k=1MLOkOkXi \frac{\partial L}{\partial X_i} = \sum_{k=1}^{M} \frac{\partial L}{\partial O_k} \frac{\partial O_k}{\partial X_i}

次に、XX 値に対してこれを展開しましょう。

LX1,1=LO1,1F1,1+LO1,2F1,2+LO2,1F2,1+LO2,2F2,2LX1,2=LO1,2F1,1+LO1,3F1,2+LO2,2F2,1+LO2,3F2,2LX2,1=LO2,1F1,1+LO2,2F1,2+LO3,1F2,1+LO3,2F2,2LX2,2=LO2,2F1,1+LO2,3F1,2+LO3,2F2,1+LO3,3F2,2 \frac{\partial L}{\partial X_{1,1}} = \frac{\partial L}{\partial O_{1,1}} F_{1,1} + \frac{\partial L}{\partial O_{1,2}} F_{1,2} + \frac{\partial L}{\partial O_{2,1}} F_{2,1} + \frac{\partial L}{\partial O_{2,2}} F_{2,2} \\ \frac{\partial L}{\partial X_{1,2}} = \frac{\partial L}{\partial O_{1,2}} F_{1,1} + \frac{\partial L}{\partial O_{1,3}} F_{1,2} + \frac{\partial L}{\partial O_{2,2}} F_{2,1} + \frac{\partial L}{\partial O_{2,3}} F_{2,2} \\ \frac{\partial L}{\partial X_{2,1}} = \frac{\partial L}{\partial O_{2,1}} F_{1,1} + \frac{\partial L}{\partial O_{2,2}} F_{1,2} + \frac{\partial L}{\partial O_{3,1}} F_{2,1} + \frac{\partial L}{\partial O_{3,2}} F_{2,2} \\ \frac{\partial L}{\partial X_{2,2}} = \frac{\partial L}{\partial O_{2,2}} F_{1,1} + \frac{\partial L}{\partial O_{2,3}} F_{1,2} + \frac{\partial L}{\partial O_{3,2}} F_{2,1} + \frac{\partial L}{\partial O_{3,3}} F_{2,2}

上記から、入力値に対する損失関数の偏微分も、LO\frac{\partial L}{\partial O}とフィルターFFの畳み込みで表現されることがわかります。

LX=LOF \frac{\partial L}{\partial X} = \frac{\partial L}{\partial O} * F

ここで観察できるのは、畳み込みと転置畳み込みのバックワードパス(入力に対して)とフォワードパスが逆転しているということです。 (畳み込みではバックワードパスに転置畳み込みを使用し、転置畳み込みではフォワードパスに畳み込みを使用します。)これが、 転置畳み込みが畳み込みの「逆」、逆畳み込みとしてしばしば言及される理由です。

コードの実装

今や畳み込みを使って次元を拡張する方法を見つけたので、畳み込みを使用してオートエンコーダーやGANのようなアーキテクチャを構築できるようになりました。 本記事では、MNISTデータセットを使ってDCGAN(Deep Convolutional Generative Adversarial Network)を構築します。 前回の記事でステップ1と2(データ探索と前処理)を既にカバーしているため、今回はモデル構築に直接取り組みます。

注意: DCGANは範囲が-1から1のtanh活性化関数を使用するため、zscoreの代わりに以下の関数でデータを正規化します。

def min_max(x, axis=None):
    min = x.min(axis=axis, keepdims=True)
    max = x.max(axis=axis, keepdims=True)
    result = 2 * (x - min) / (max - min) - 1
    return result

ステップ 3. モデル

以下は、TensorFlowとPyTorchでのDCGANの実装例です。

畳み込みを使用しているにもかかわらず、トレーニングにかなりの時間がかかることに気づくかもしれません。 これは、おそらく、より大きな潜在次元が非常に大きな次元に投影されるためです。(バッチ正規化も導入しています。) したがって、前回構築したGANと結果を直接比較するのは公平ではありません。トレーニングを高速化するには、 GPUを使用することをお勧めします。

ステップ 4. モデルの評価

モデルをトレーニングした後、DCGANのジェネレーターを取り出し、適切なサイズのノイズを入力して新しい画像を生成することができます。 PyTorchで実装されたDCGANを30エポックでトレーニングした後の画像を見てみましょう。

DCGAN

わずか30エポックのトレーニングでも、手書きの数字に似た、はるかに鮮明な画像がすでに見られます。挑戦として、異なるアーキテクチャを持つGANを構築してみてください。

次元計算のヒント

転置畳み込み層に慣れていない場合、特定のカーネルサイズ、ストライド、パディングが使用されたときの層の出力次元に混乱するかもしれません。 そのような場合は、次の式を使用して出力次元を計算できます。

Dout=(Din1)s2p+k D_{out} = (D_{in}-1)s - 2p + k

ここで、DoutD_{out} は転置畳み込み後の出力次元、DinD_{in} は転置畳み込み前の入力次元、pp はパディング、kk はカーネルサイズ、ss はストライドです。

結論

前回と今回の2つの記事で、新しい層である畳み込み層と転置畳み込み層について説明しました。これらの層を使用することで、 予測器、分類器、特徴抽出器、生成モデルを作成できます。モデルは効率と品質の面で大幅に向上しましたが、 小さな画像であっても現実的なスピードで適切な品質にトレーニングするには、まだGPUが必要です。 大きな画像をトレーニングしたい場合、さまざまな技術と非常に高いハードウェアリソースが必要になります。 そこで、次回の記事では、大規模なモデルをトレーニングするために他の人々が培ってきた技術を上手く活用する方法について説明します。

リソース