statsuのblog

愛知のデータサイエンティスト。自分の活動記録。主に機械学習やその周辺に技術について学んだことを記録していく予定。

Deep EnsemblesでDeep Learningの不確かさを評価する

前回の記事に引き続き、Deep learningの不確かさ評価についてです。
今回は、「Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles」という論文で紹介されているアンサンブルで不確かさを評価する手法の検証を実施しました。

検証コードはgithubgithubにあげてあります。
github.com

以下、要点のみメモ。

アンサンブル学習での不確かさ評価

論文及び参考サイト

手法

  • 目的
    • DLの推定値の不確かさを定量評価する。
  • 手法の概要
    • 通常のDLでは推定値を出力する。例えば、入力情報から土地価格を推定したければDLが土地価格を出力するようにネットワークを作る。
    • 提案されているDLでは、推定値の平均と分散(または推定値の確率分布のパラメータ)を出力するようにネットワークを作る。最小化する損失関数は、推定値の確率分布の対数尤度である。※損失関数は推定値、分散、推定値の正解データから成り、分散の正解データは不要。
    • ブーストラップサンプリングで元データセットからM個のデータセットを作成し、それぞれのデータセットを使ってM個のネットワークを学習する。
    • M個のネットワークの推定値の平均値を最終的な推定値とする。
    • M個のネットワークの混合分布の分散を最終的な分散とする。

kerasでの実装例

ラベル0 or 1の二値分類を対象としたネットワークを作った実装例を紹介します。

問題設定は以下のとおりです。

  • ネットワークの推定値μは、ラベルxの平均値であり0から1の間の値。
  • ラベルxの確率分布が、定義範囲を限定した正規分布p(x : μ, σ^2) (0≦x≦1)に従うと仮定する。このσ^2もネットワークで出力する。
  • 損失関数は、定義範囲を限定した正規分布p(x : μ, σ^2) (0≦x≦1)の対数尤度関数。

kerasでこれらを実装するためには、(1)出力が2種類のネットワーク作成、(2)自作の損失関数の使用、が必要となります。以下、実装例です。

(1) 出力が2種類のネットワーク
推定値μはsigmoid関数、σ^2はsoftplus関数で計算します。σ^2が0に近づきすぎると数値的に不安定になるため、1e-6を足しています。

def built_model(self, input_shape=None):
        """
        model input: image shape(32, 32, 3)
        model output: probability of being label1, uncertainty score.
                      Range is [0,1] and [0,1], respectively. 
        """
        # constants
        if input_shape is None:
            # assume cifar10 image
            input_shape = (32, 32, 3)

        # model structure
        input_img = Input(input_shape)
        h = input_img

        h = Conv2D(32, (3, 3), padding='same')(h)
        h = BatchNormalization()(h)
        h = Activation('relu')(h)
        h = Conv2D(32, (3, 3))(h)
        h = BatchNormalization()(h)
        h = Activation('relu')(h)
        h = MaxPooling2D(pool_size=(2, 2))(h)

        h = Conv2D(64, (3, 3), padding='same')(h)
        h = BatchNormalization()(h)
        h = Activation('relu')(h)
        h = Conv2D(64, (3, 3))(h)
        h = BatchNormalization()(h)
        h = Activation('relu')(h)
        h = MaxPooling2D(pool_size=(2, 2))(h)

        oup_cnn = Flatten()(h)        
        oup_cnn = Dense(32)(oup_cnn)
        oup_cnn = BatchNormalization()(oup_cnn)
        oup_cnn = Activation('relu')(oup_cnn)
        
        # expec
        h_expec = Dense(1)(oup_cnn)
        h_expec = Activation('sigmoid')(h_expec)

        # var
        h_var = Dense(1)(oup_cnn)
        h_var = Activation('softplus')(h_var)
        h_var = Lambda(lambda x: x + 1e-6, output_shape=(1,))(h_var)

        oup = Concatenate(axis=-1)([h_expec, h_var])

        # model
        self.model = Model(inputs=input_img, outputs=oup)

        self.model.summary()

        return

(2) 自作の損失関数の使用
損失関数を定義します。ここでは定義範囲を区切った正規分布を使っているので、規格化係数([0, 1]の範囲での積分値を1にする)の計算があってごちゃごちゃしています。また、σ^2が大きい値を取りやすかったので正則化項loss_reg_varを加えています。

def part_norm_dist_log_likelihood(self, y_true, y_pred):
        """
        expec = y_pred[:,0]
        var = y_pred[:,1]

        -ln(L) = ln(I) + 0.5 * ln(var) + 0.5 * (x - expec)^2 / var

        """
        expec = y_pred[:,0:1]
        var = y_pred[:,1:2]

        loss_var = 0.5 * K.log(var)
        loss_l2 = 0.5 * K.square(y_true - expec) / var

        I = 0.5 * (tf.math.erf((1.0 - expec) / K.sqrt(2.0 * var)) - tf.math.erf((0.0 - expec) / K.sqrt(2.0 * var)))
        loss_I = K.log(I)

        loss_reg_var = K.sqrt(var) * 16.0 # regularization of var

        loss = loss_I + loss_var + loss_l2 + loss_reg_var

        return loss

モデルの損失関数として自作損失関数を以下の方法で設定します。

self.model.compile(loss=self.part_norm_dist_log_likelihood, optimizer='nadam', metrics=['accuracy'])

検証

cifar10の犬猫画像の二値分類を対象として、上記の手法で不確かさ評価を実施してみました。
検証の設定は以下のとおりです。

  • 対象は犬猫画像の二値分類問題。
    • 画像データとしてcifar10を使用。
    • ラベルは犬が1、猫が0。
  • 上記手法で推定値の不確かさを評価。
    • 推定値、分散、損失関数の設定は上述の「kerasでの実装例」のとおり。
    • アンサンブルの数は10個。
    • 学習に使っていないラベル(飛行機、自動車、鳥、鹿、カエル、馬、船、トラック)についても評価を実施。

※以下のグラフのuncertainty coefは、アンサンブルで求めた標準偏差です。

犬猫分類の不確かさ評価結果

犬猫の分類結果を以下に示します。推定値が1のとき犬、0のとき猫です。
推定値が0.5に近いほど推定値の標準偏差(不確かさ)が大きい傾向がわかります。

f:id:st1990:20190815195109p:plain
推定値の標準偏差ヒストグラム
f:id:st1990:20190815195201p:plain
推定値の標準偏差 vs 推定値 (学習データ)
f:id:st1990:20190815195242p:plain
推定値の標準偏差 vs 推定値 (テストデータ)
f:id:st1990:20190815195747p:plain
推定値の標準偏差 vs 推定値 (犬)
f:id:st1990:20190815195819p:plain
推定値の標準偏差 vs 推定値 (猫)

テストデータのROC曲線を下図に示します。
「threshold=y」ラベルは、犬か猫か判定する閾値を単純に推定値としてROC曲線を描いたものです。推定値>閾値であれば犬と判定されます。
「threshold=0.5+a * std」ラベルは、閾値を0.5+astd (-∞<a<∞)としてROC曲線を描いたものです。推定値>0.5+astdであれば犬と判定されます。推定値が0.5に対して不確かさを含めてどの程度余裕があるかを表しています。
結果を見ると…ほぼ同じですね。

f:id:st1990:20190816024535p:plain
ROC曲線(テストデータ)

学習に使っていないラベルでの不確かさ評価結果

学習に使っていないラベル(飛行機、自動車、鳥、鹿、カエル、馬、船、トラック)についても不確かさ評価を実施しました。
推定値の標準偏差ヒストグラム、推定値の標準偏差に対する推定値のグラフを以下に示します。
0.5付近の推定値が増えるので、ヒストグラムのピークとなる標準偏差の値が大きくなったことがわかります。しかし、標準偏差vs推定値の分布は犬猫画像と変わりなく見えます。学習に使っていないラベルは学習データのデータ分布外なので、推定値が0や1に近くても標準偏差が大きくなるかと思っていたのですが。これはDropoutを使った不確かさ評価と同じ結果ですね。
f:id:st1990:20190815200528p:plainf:id:st1990:20190815200510p:plainf:id:st1990:20190815195819p:plainf:id:st1990:20190815200513p:plainf:id:st1990:20190815200516p:plainf:id:st1990:20190815200522p:plainf:id:st1990:20190815200525p:plainf:id:st1990:20190815200531p:plainf:id:st1990:20190815200534p:plainf:id:st1990:20190815200539p:plainf:id:st1990:20190815200541p:plain

以上