簡単な2次元問題でWGAN-gpの基礎理解を深める(python, keras)
WGAN-gpについて理解するため、WGAN-gpを簡単な2次元問題に適用し、その挙動を観察してみました。また、GANとの違いを比較しました
前にGANでやったことの続きです。
簡単な2次元問題でGANの基礎理解を深める(python, keras) - st1990のblog
以下の検証に関するコードはgithubにあげてあります。
github.com
1. 本記事の概要
- 記事の目的
- WGAN-gpの概要
- WGAN-gpの自分的な理解
- 簡単な2次元問題へのWGAN-gpの適用(python、kerasを使った実装)、結果観察
2. 記事の目的
- WGAN-gpについて勉強したので、自分が理解したことを記録に残す。
- WGAN-gpを簡単な問題に適用し、実装方法やその特徴を学ぶ。
3. WGAN-gpの概要
こちらのサイトにわかりやすく解説されているので省略します。
今さら聞けないGAN(4) WGAN - Qiita
今さら聞けないGAN (5) WGAN-gpの実装 - Qiita
4. WGAN-gpの自分的な理解
メモレベルですが。
- GANについて。
- GANでは、識別モデルが誤判定するように生成モデルを学習させることで、生成データの分布を本物データに近づけていきました。これは、Jensen-Shannon divergenceという指標を使って本物データと生成データの確率分布の差を測り、それを最小化していくというコンセプトに基づいています。
- Jensen-Shannon divergenceには勾配消失が起こりやすい、モード崩壊が起こるといった問題があります。
- GANでは、識別モデルが誤判定するように生成モデルを学習させることで、生成データの分布を本物データに近づけていきました。これは、Jensen-Shannon divergenceという指標を使って本物データと生成データの確率分布の差を測り、それを最小化していくというコンセプトに基づいています。
- WGAN-gpについて。
- WGAN-gpはWasserstein距離によって本物データと生成データの確率分布の差を測り、それを最小化していくというコンセプトに基づいています。
- discriminatorの役割が、「本物データか偽物データか識別する」というものから「本物データと生成データのWasserstein距離を推定する」に代わっています。
- Wasserstein距離を使うため、GANとは損失関数が異なります。(損失関数) = (Wasserstein距離) + (勾配の制約)というイメージ。勾配の制約が加わっているのは、Wasserstein距離を使うための制約を満たすため(gp = gradient penalty)。
- WGAN-gpはWasserstein距離によって本物データと生成データの確率分布の差を測り、それを最小化していくというコンセプトに基づいています。
5. 簡単な2次元問題へのWGAN-gpの適用
簡単な2次元問題へWGAN-gpを適用し、データ生成の様子を観察しました。
mnistの文字生成などがWGAN-gpの導入としてよく紹介されていますが、データ分布という観点での観察が難しいので、ここでは2次元を選びました。
目的
- 本物データのデータ空間に属するデータを生成モデルが生成していることを目視確認する。
- 実装方法を学ぶ。
問題設定
- 本物データに似たデータを生成するGANモデルをつくる。
- 学習手順は上述のとおり。
- 本物データは2次元空間上の下図の点
実装
python、kerasを使って実装しました。
github.com
結果と考察
WGA-gpとGANの学習が進む様子を下図に示します。
この結果より以下のことがわかります。
- 最初は生成データはランダムであるが、学習が進むにつれて本物データと似たような分布になっていく。
- GANはあまり収束していくように見えないが、WGAN-gpは収束していっている(epochが増えていっても分布の変化が小さい)。
- GANと比べてWGAN-gpはモード崩壊が少ない。
WGA-gp良いですね。
WGAN-gpでは損失関数に勾配の情報を使うと説明しましたが、勾配に関する損失関数をどの程度考慮するかというハイパーパラメータがあります。(損失関数) = (Wasserstein距離) +b * (勾配の制約)のbです。
これを大きくしたときの結果ものせておきます。
がっつりモード崩壊しています。モード崩壊というか、ちゃんと学習できていないということなのかな?
難しかったこと・工夫したこと・得た知見
- 活性化関数は、生成モデルと識別モデルの両方でlearkyReluを使いました。GANのときは識別モデルにはReluを使わないとうまくいきませんでしたが、これはGANとWGA-gpで識別モデルの役割が違うため(上述)だと思います。
6. まとめ
- WGAN-gpの基本的な概念を説明しました。
- WGAN-gpの概念についての自分の理解・イメージを説明しました。
- 簡単な2次元問題にWGAN-gpを適用しました。その結果、本物データの分布へ生成データの分布が一致していく様子などを観察でき、GANとの違いも確認できました。