Deep EnsemblesでDeep Learningの不確かさを評価する
前回の記事に引き続き、Deep learningの不確かさ評価についてです。
今回は、「Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles」という論文で紹介されているアンサンブルで不確かさを評価する手法の検証を実施しました。
検証コードはgithubgithubにあげてあります。
github.com
以下、要点のみメモ。
アンサンブル学習での不確かさ評価
論文及び参考サイト
- 論文
[1612.01474] Simple and Scalable Predictive Uncertainty Estimation using Deep Ensembles - 論文の概要説明(日本語)
私たちが愛した3つのNIPS論文 | techcareer magazine | エンジニアの転職(キャリア)情報満載のWebメディア
手法
- 目的
- DLの推定値の不確かさを定量評価する。
- DLの推定値の不確かさを定量評価する。
- 手法の概要
- 通常のDLでは推定値を出力する。例えば、入力情報から土地価格を推定したければDLが土地価格を出力するようにネットワークを作る。
- 提案されているDLでは、推定値の平均と分散(または推定値の確率分布のパラメータ)を出力するようにネットワークを作る。最小化する損失関数は、推定値の確率分布の対数尤度である。※損失関数は推定値、分散、推定値の正解データから成り、分散の正解データは不要。
- ブーストラップサンプリングで元データセットからM個のデータセットを作成し、それぞれのデータセットを使ってM個のネットワークを学習する。
- M個のネットワークの推定値の平均値を最終的な推定値とする。
- M個のネットワークの混合分布の分散を最終的な分散とする。
- 通常のDLでは推定値を出力する。例えば、入力情報から土地価格を推定したければDLが土地価格を出力するようにネットワークを作る。
kerasでの実装例
ラベル0 or 1の二値分類を対象としたネットワークを作った実装例を紹介します。
問題設定は以下のとおりです。
- ネットワークの推定値μは、ラベルxの平均値であり0から1の間の値。
- ラベルxの確率分布が、定義範囲を限定した正規分布p(x : μ, ) (0≦x≦1)に従うと仮定する。このもネットワークで出力する。
- 損失関数は、定義範囲を限定した正規分布p(x : μ, ) (0≦x≦1)の対数尤度関数。
kerasでこれらを実装するためには、(1)出力が2種類のネットワーク作成、(2)自作の損失関数の使用、が必要となります。以下、実装例です。
(1) 出力が2種類のネットワーク
推定値μはsigmoid関数、はsoftplus関数で計算します。が0に近づきすぎると数値的に不安定になるため、1e-6を足しています。
def built_model(self, input_shape=None): """ model input: image shape(32, 32, 3) model output: probability of being label1, uncertainty score. Range is [0,1] and [0,1], respectively. """ # constants if input_shape is None: # assume cifar10 image input_shape = (32, 32, 3) # model structure input_img = Input(input_shape) h = input_img h = Conv2D(32, (3, 3), padding='same')(h) h = BatchNormalization()(h) h = Activation('relu')(h) h = Conv2D(32, (3, 3))(h) h = BatchNormalization()(h) h = Activation('relu')(h) h = MaxPooling2D(pool_size=(2, 2))(h) h = Conv2D(64, (3, 3), padding='same')(h) h = BatchNormalization()(h) h = Activation('relu')(h) h = Conv2D(64, (3, 3))(h) h = BatchNormalization()(h) h = Activation('relu')(h) h = MaxPooling2D(pool_size=(2, 2))(h) oup_cnn = Flatten()(h) oup_cnn = Dense(32)(oup_cnn) oup_cnn = BatchNormalization()(oup_cnn) oup_cnn = Activation('relu')(oup_cnn) # expec h_expec = Dense(1)(oup_cnn) h_expec = Activation('sigmoid')(h_expec) # var h_var = Dense(1)(oup_cnn) h_var = Activation('softplus')(h_var) h_var = Lambda(lambda x: x + 1e-6, output_shape=(1,))(h_var) oup = Concatenate(axis=-1)([h_expec, h_var]) # model self.model = Model(inputs=input_img, outputs=oup) self.model.summary() return
(2) 自作の損失関数の使用
損失関数を定義します。ここでは定義範囲を区切った正規分布を使っているので、規格化係数([0, 1]の範囲での積分値を1にする)の計算があってごちゃごちゃしています。また、が大きい値を取りやすかったので正則化項loss_reg_varを加えています。
def part_norm_dist_log_likelihood(self, y_true, y_pred): """ expec = y_pred[:,0] var = y_pred[:,1] -ln(L) = ln(I) + 0.5 * ln(var) + 0.5 * (x - expec)^2 / var """ expec = y_pred[:,0:1] var = y_pred[:,1:2] loss_var = 0.5 * K.log(var) loss_l2 = 0.5 * K.square(y_true - expec) / var I = 0.5 * (tf.math.erf((1.0 - expec) / K.sqrt(2.0 * var)) - tf.math.erf((0.0 - expec) / K.sqrt(2.0 * var))) loss_I = K.log(I) loss_reg_var = K.sqrt(var) * 16.0 # regularization of var loss = loss_I + loss_var + loss_l2 + loss_reg_var return loss
モデルの損失関数として自作損失関数を以下の方法で設定します。
self.model.compile(loss=self.part_norm_dist_log_likelihood, optimizer='nadam', metrics=['accuracy'])
検証
cifar10の犬猫画像の二値分類を対象として、上記の手法で不確かさ評価を実施してみました。
検証の設定は以下のとおりです。
- 対象は犬猫画像の二値分類問題。
- 画像データとしてcifar10を使用。
- ラベルは犬が1、猫が0。
- 上記手法で推定値の不確かさを評価。
- 推定値、分散、損失関数の設定は上述の「kerasでの実装例」のとおり。
- アンサンブルの数は10個。
- 学習に使っていないラベル(飛行機、自動車、鳥、鹿、カエル、馬、船、トラック)についても評価を実施。
- 推定値、分散、損失関数の設定は上述の「kerasでの実装例」のとおり。
※以下のグラフのuncertainty coefは、アンサンブルで求めた標準偏差です。
犬猫分類の不確かさ評価結果
犬猫の分類結果を以下に示します。推定値が1のとき犬、0のとき猫です。
推定値が0.5に近いほど推定値の標準偏差(不確かさ)が大きい傾向がわかります。
テストデータのROC曲線を下図に示します。
「threshold=y」ラベルは、犬か猫か判定する閾値を単純に推定値としてROC曲線を描いたものです。推定値>閾値であれば犬と判定されます。
「threshold=0.5+a * std」ラベルは、閾値を0.5+astd (-∞<a<∞)としてROC曲線を描いたものです。推定値>0.5+astdであれば犬と判定されます。推定値が0.5に対して不確かさを含めてどの程度余裕があるかを表しています。
結果を見ると…ほぼ同じですね。
学習に使っていないラベルでの不確かさ評価結果
学習に使っていないラベル(飛行機、自動車、鳥、鹿、カエル、馬、船、トラック)についても不確かさ評価を実施しました。
推定値の標準偏差のヒストグラム、推定値の標準偏差に対する推定値のグラフを以下に示します。
0.5付近の推定値が増えるので、ヒストグラムのピークとなる標準偏差の値が大きくなったことがわかります。しかし、標準偏差vs推定値の分布は犬猫画像と変わりなく見えます。学習に使っていないラベルは学習データのデータ分布外なので、推定値が0や1に近くても標準偏差が大きくなるかと思っていたのですが。これはDropoutを使った不確かさ評価と同じ結果ですね。
以上