statsuのblog

愛知のデータサイエンティスト。自分の活動記録。主に機械学習やその周辺に技術について学んだことを記録していく予定。

簡単な2次元問題でWGAN-gpの基礎理解を深める(python, keras)

WGAN-gpについて理解するため、WGAN-gpを簡単な2次元問題に適用し、その挙動を観察してみました。また、GANとの違いを比較しました
前にGANでやったことの続きです。
簡単な2次元問題でGANの基礎理解を深める(python, keras) - st1990のblog

以下の検証に関するコードはgithubにあげてあります。
github.com

1. 本記事の概要

  • 記事の目的
  • WGAN-gpの概要
  • WGAN-gpの自分的な理解
  • 簡単な2次元問題へのWGAN-gpの適用(python、kerasを使った実装)、結果観察

2. 記事の目的

  • WGAN-gpについて勉強したので、自分が理解したことを記録に残す。
  • WGAN-gpを簡単な問題に適用し、実装方法やその特徴を学ぶ。

3. WGAN-gpの概要

こちらのサイトにわかりやすく解説されているので省略します。
今さら聞けないGAN(4) WGAN - Qiita
今さら聞けないGAN (5) WGAN-gpの実装 - Qiita

4. WGAN-gpの自分的な理解

メモレベルですが。

  • GANについて。
    • GANでは、識別モデルが誤判定するように生成モデルを学習させることで、生成データの分布を本物データに近づけていきました。これは、Jensen-Shannon divergenceという指標を使って本物データと生成データの確率分布の差を測り、それを最小化していくというコンセプトに基づいています。
    • Jensen-Shannon divergenceには勾配消失が起こりやすい、モード崩壊が起こるといった問題があります。
  • WGAN-gpについて。
    • WGAN-gpはWasserstein距離によって本物データと生成データの確率分布の差を測り、それを最小化していくというコンセプトに基づいています。
    • discriminatorの役割が、「本物データか偽物データか識別する」というものから「本物データと生成データのWasserstein距離を推定する」に代わっています。
    • Wasserstein距離を使うため、GANとは損失関数が異なります。(損失関数) = (Wasserstein距離) + (勾配の制約)というイメージ。勾配の制約が加わっているのは、Wasserstein距離を使うための制約を満たすため(gp = gradient penalty)。

5. 簡単な2次元問題へのWGAN-gpの適用

簡単な2次元問題へWGAN-gpを適用し、データ生成の様子を観察しました。
mnistの文字生成などがWGAN-gpの導入としてよく紹介されていますが、データ分布という観点での観察が難しいので、ここでは2次元を選びました。

目的

  • 本物データのデータ空間に属するデータを生成モデルが生成していることを目視確認する。
  • 実装方法を学ぶ。

問題設定

  • 本物データに似たデータを生成するGANモデルをつくる。
  • 学習手順は上述のとおり。
  • 本物データは2次元空間上の下図の点
    f:id:st1990:20190616234253p:plain
    本物データの分布

実装

python、kerasを使って実装しました。
github.com

結果と考察

WGA-gpとGANの学習が進む様子を下図に示します。

f:id:st1990:20190616234253p:plain
本物データの分布
f:id:st1990:20190714185341g:plain
WGAN-gpの生成データの学習の様子
f:id:st1990:20190620005836g:plain
GANの生成データの学習の様子

この結果より以下のことがわかります。
- 最初は生成データはランダムであるが、学習が進むにつれて本物データと似たような分布になっていく。
- GANはあまり収束していくように見えないが、WGAN-gpは収束していっている(epochが増えていっても分布の変化が小さい)。
- GANと比べてWGAN-gpはモード崩壊が少ない。
WGA-gp良いですね。

WGAN-gpでは損失関数に勾配の情報を使うと説明しましたが、勾配に関する損失関数をどの程度考慮するかというハイパーパラメータがあります。(損失関数) = (Wasserstein距離) +b * (勾配の制約)のbです。
これを大きくしたときの結果ものせておきます。

f:id:st1990:20190714190146g:plain
WGAN-gpの生成データの学習の様子(損失関数における勾配情報の重み大)
がっつりモード崩壊しています。モード崩壊というか、ちゃんと学習できていないということなのかな?

難しかったこと・工夫したこと・得た知見

  • 活性化関数は、生成モデルと識別モデルの両方でlearkyReluを使いました。GANのときは識別モデルにはReluを使わないとうまくいきませんでしたが、これはGANとWGA-gpで識別モデルの役割が違うため(上述)だと思います。

6. まとめ

  • WGAN-gpの基本的な概念を説明しました。
  • WGAN-gpの概念についての自分の理解・イメージを説明しました。
  • 簡単な2次元問題にWGAN-gpを適用しました。その結果、本物データの分布へ生成データの分布が一致していく様子などを観察でき、GANとの違いも確認できました。