statsuのblog

愛知のデータサイエンティスト。自分の活動記録。主に機械学習やその周辺に技術について学んだことを記録していく予定。

簡単な2次元問題でGANの基礎理解を深める(python, keras)

GANについて理解するため、GANを簡単な2次元問題に適用し、その挙動を観察してみました。実装にはpythonとkerasを使いました。
mnistの文字生成などがGANの導入としてよく紹介されていますが、文字等の画像データはデータ分布形状という観点での観察が難しいので、ここでは2次元を選びました。

以下の検証に関するコードはgithubにあげてあります。
github.com

1. 本記事の概要

  • 記事の目的
  • GANの概要
  • GANの自分的な理解
  • 簡単な2次元問題へのGANの適用(python、kerasを使った実装)、結果観察

2. 記事の目的

  • 今回初めてGANについて勉強したので、自分が理解したことを記録に残す。
  • GANを簡単な問題に適用し、実装方法やその特徴を学ぶ。

3. GANの概要

GANとは

  • 敵対的生成ネットワークのこと。
  • 学習データとして与えたもの以外のもの(絵、画像など)を生成したりするような用途で使われる。
  • 生成モデルと識別モデルで構成される。生成モデルは本物データに似せた偽物データを作る。識別モデルは本物と偽物データを見分ける。両モデルを戦わせながらどちらも高めていくと、最終的に本物そっくりな偽物を作る生成モデルができあがる。
  • モデル構造にはdeep learningが使われる。
  • 上記のような変わった学習方法なので、学習させるのが他のdeep learningと比べて難しい。

GANでできること

GANについてのサイト集

4. GANの自分的な理解

一番基本的なGANについて自分的な理解を記録しておきます。
ここでいうGANとは、本物データに似た偽物データを作るようなGANを指しています。

基本概念と学習のイメージ

f:id:st1990:20190704002251p:plain
ganのイメージ

基本概念

  • GANの目的は、『本物データxが属するデータ空間Xを求め、空間Xに属するデータx'を生成すること』。
  • 生成モデルと識別モデルから構成される。それぞれの役割は以下のとおり。
    • 識別モデル
      • 「あるデータがデータ空間Xに属するか判定する」=「本物データxが属するデータ空間Xを表す」
    • 生成モデル
      • 「乱数からデータ空間Xに属するデータx'を生成する」

学習

  • 学習の手順は以下のとおり。
    • ① 乱数から生成モデルで偽物データを生成する。
    • ② "本物データ"と"①の偽物データ"を正しく識別するよう識別モデルを学習させる。
    • ③ 識別モデルが「偽物データを本物データと誤識別」するよう生成モデルを学習させる。
    • ①~③を繰り返す。
  • 基本概念と学習の関係は以下のとおり。
    • 学習手順②で、どのようなデータが本物データのデータ空間Xに属するかを認識できるようになる。
    • 学習手順③で、データ空間Xに属するデータを生成できるようになる。

5. 簡単な2次元問題へのGANの適用

簡単な2次元問題へGANを適用し、データ生成の様子を観察しました。
mnistの文字生成などがGANの導入としてよく紹介されていますが、データ分布という観点での観察が難しいので、ここでは2次元を選びました。

目的

  • 本物データのデータ空間に属するデータを生成モデルが生成していることを目視確認する。
  • 実装方法を学ぶ。

問題設定

  • 本物データに似たデータを生成するGANモデルをつくる。
  • 学習手順は上述のとおり。
  • 本物データは2次元空間上の下図の点
    f:id:st1990:20190616234253p:plain
    本物データの分布

実装

python、kerasを使って実装しました。
github.com

結果と考察

学習が進む様子、識別モデルの識別領域を下図に示します。

f:id:st1990:20190616234253p:plain
本物データの分布
f:id:st1990:20190620005836g:plain
生成データの学習の様子
f:id:st1990:20190620005957p:plain
識別モデル(学習後)が本物と判定する領域
この結果より以下のことがわかります。

  • 最初は生成データはランダムであるが、学習が進むにつれて本物データと似たような分布になっていく。
  • 学習が進んでも生成データの分布は収束せず、変化し続ける。
  • 生成データの分布に偏りがあり、モード崩壊を起こしている。
  • 識別モデルが本物データの分布を正確に識別できるようにはなっていない。

生成データが本物データの分布に近づいていくことを観察できたのでとりあえず目標達成です。

難しかったこと・工夫したこと・得た知見

  • 基本的に学習が不安定で、同じ条件であっても良い結果になることもあれば変な結果になることもあった。
  • モデル構造を決めるのが難しかった。問題の難易度に比べて、かなり複雑なモデルになっている。表現力が高いモデルを使う必要があるみたい。
  • 生成モデルの入力となる乱数の次元は、問題の複雑さに比べて大きめに設定したほうが本物分布をよく再現できた。問題の特徴からすると乱数2つで十分かと思っていたので意外だった。
  • 活性化関数は、生成モデルではlearkyRelu、識別モデルではReluを使うことでうまく学習できた。逆ではうまく学習できなかった。
  • 生成モデルの出力をtanhで[-1,1]の範囲にすることで学習が安定した。これをしないと生成モデルの出力が好き勝手な値をとってしまい学習が不安定であった。実際の応用でも本物データを[-1,1]に正規化し、生成モデルの出力にtanhを使うのが良さそう。このサイトを参考にしました。
    個人的GANのTipsまとめ(随時更新予定) - Qiita

6. まとめ

  • GANの基本的な概念を説明しました。
  • GANの概念についての自分の理解・イメージを説明しました。
  • 簡単な2次元問題にGANを適用しました。その結果、本物データの分布へ生成データの分布が一致していく様子などを観察できました。