YOTO(YOU ONLY TRAIN ONCE)を不均衡データ対策の損失関数に適用して画像分類してみた
この記事は、YOTO(YOU ONLY TRAIN ONCE)の雰囲気を掴むことを目的として、不均衡データ対策の損失関数にYOTOを適用して画像分類してみた記録です。
YOTOを使うことで、1つのモデルでMajorクラスの性能が良いモデル or Minorクラスの性能が良いモデルをテスト時に選択できました。
検証に使ったコードはgithubにあります。
GitHub - statsu1990/yoto_class_balanced_loss: Unofficial implementation of YOTO (You Only Train Once) applied to Class balanced loss
記事の概要
YOTO(YOU ONLY TRAIN ONCE)について
- ICLR2020でGoogleから発表されたDeep learningに関する技術。
- YOTOを適用したモデルは、モデルに損失関数のハイパーパラメータを入力できるようになり、推論時にハイパーパラメータの値を変えられる。普通であればハイパーパラメータの値毎に異なるモデルが必要となるが、YOTOでは1つのモデルで異なるハイパーパラメータに対応できるようになる。
- 論文中では画像圧縮タスクとstyle transferタスクに適用されている。YOTOを使うことで、1つのモデルで圧縮~画質の調整、style強度の調整ができるようになっていた。
- ICLR2020でGoogleから発表されたDeep learningに関する技術。
不均衡データ対策の損失関数へのYOTOの適用
- この記事ではYOTOを不均衡データでの画像分類タスクに適用してみた。
- 不均衡データでの分類タスクでは、クラス毎のデータ数に偏りがある。不均衡データをそのまま学習するとデータ数が多いクラス(Majorクラス)が優先的に学習され、データ数の少ないクラス(Minorクラス)の推定精度が低くなることがある。
- 損失関数を工夫することでMinorクラスの推定精度を上げることができる。この記事ではClass Balanced Lossを使う。Class Balanced Lossには、「Minorクラスをどれだけ優先するか」というハイパーパラメータβがある。βが小さければMajorクラスが優先され、βが大きければMinorクラスが優先される。YOTOとClass Balanced Lossを組み合わせることで、1つのモデルでMajorクラスの性能が良いモデル or Minorクラスの性能が良いモデルをテスト時に選択できるようになるはず。
- この記事ではYOTOを不均衡データでの画像分類タスクに適用してみた。
- 検証
- YOTOとClass Balanced Lossを組み合わせることで、1つのモデルでMajorクラスの性能が良いモデル or Minorクラスの性能が良いモデルをテスト時に選択できるようになるか検証した。
- Cifar10とCifar100を不均衡データにしたデータセットを使った。モデルにはResNet18を使った。
- 1つのモデルでMajorクラスの性能が良いモデル or Minorクラスの性能が良いモデルをテスト時に選択できること(Minorクラスの優先度を変えた推論をできること)を確認した。
- ハイパーパラメータ固定で学習したモデルとYOTOで学習したモデルの性能を比較すると、同じハイパーパラメータでも性能は完全には一致しなかった。Majorクラスの分類精度は若干下がり、Minorクラスの分類精度は向上した。同じモデルサイズであれば性能が下がるのは論文のとおりなので納得だが、Minorクラスの分類精度が向上したのは意外だった。ハイパーパラメータ固定だと学習が不安定になるケースでも、YOTOだと安定して学習できて性能上がることがあるのかも(※個人的な推測です)。
- YOTOとClass Balanced Lossを組み合わせることで、1つのモデルでMajorクラスの性能が良いモデル or Minorクラスの性能が良いモデルをテスト時に選択できるようになるか検証した。
YOTO (YOU ONLY TRAIN ONCE)について
YOTOはICLR2020でGoogleから発表されたDeep learningに関する技術です。論文のタイトルは"YOU ONLY TRAIN ONCE: LOSS-CONDITIONAL TRAINING OF DEEP NETWORKS"で、YOTOとはYOU ONLY TRAIN ONCEの略です。Object detectionで使われるYOLO(You Only Look Once)のパロディ的な名称ですね。カッコエエ。
論文:https://openreview.net/pdf?id=HyxY6JHKwr
Google AI Blog:Google AI Blog: Optimizing Multiple Loss Functions with Loss-Conditional Training
YOTOの概要
ざっくり概要を説明します。Google AI Blogにコンセプトが良くまとまっているので、詳しくはそっちを見てください。
Deep learningで使われる損失関数にはハイパーパラメータを含むものがあり、ハイパーパラメータの値が異なると学習されたモデルの性能や特徴が変わります。例えば、画像圧縮モデル(画像を入力すると合宿した画像を出力する)を学習するときには、損失関数には「ファイルサイズと画質のどちらを優先するか」というハイパーパラメータが含まれます。次式のような損失関数であれば、λがそのハイパーパラメータです。
圧縮サイズを優先するモデルと画質を優先するモデルが欲しければ、ハイパーパラメータを変えて学習させ、別々に2つのモデルを作る必要があります。
YOTOを適用したモデルでは、損失関数のハイパーパラメータを入力として受け取り、そのハイパーパラメータに対応した出力を出せるようになります。つまり、ハイパーパラメータ毎にモデルを用意する必要がないので学習時間や容量が減って嬉しいです。また、推論時にハイパーパラメータの値をいじって出力を調整できるのも嬉しいです。
論文で紹介されていたStyle transferでは、下図のように1つのモデルでStyle強度を調整できるようになっています。
上記の画像圧縮の例とすると、従来とYOTOは以下のような違いがあります。
従来
- ファイルサイズ優先させた損失関数を使って学習したモデルに画像を入力し、ファイルサイズ優先の圧縮画像を出力する。
- 画質優先させた損失関数を使って学習したモデルに画像を入力し、画質優先の圧縮画像を出力する。
- ファイルサイズ優先させた損失関数を使って学習したモデルに画像を入力し、ファイルサイズ優先の圧縮画像を出力する。
YOTO
- YOTOを使って学習したモデルに画像とファイルサイズ優先のパラメータを入力し、ファイルサイズ優先の圧縮画像を出力する。
- YOTOを使って学習したモデルに画像と画質優先のパラメータを入力し、画質優先の圧縮画像を出力する。
- YOTOを使って学習したモデルに画像とファイルサイズ優先のパラメータを入力し、ファイルサイズ優先の圧縮画像を出力する。
YOTOのモデル構造
YOTOのモデル構造は下図のとおりです。
簡単に説明しておきます。
- Main network:従来のモデル構造。例えば、画像分類だったらResNet、SegmentationだったらUNetなど。
- α:損失関数のハイパーパラメータ。学習時には指定した値域内からランダムサンプリングします。推論時は使いたい値に設定する。
- Conditioning network:αを入力とし、σとμという量を出力するNeural Network。Main networkと一緒に学習する。論文ではσ、μそれぞれを出力するMLP(multi layer perceptron)が使われていた。
- Conditioning networkとMain networkのつなぎ:Conditioning networkの出力σとμをFiLMという手法でMain networkに入れて、αによる影響をMain networkに組み込む。
- FiLM:Feature-wise Linear Modulationの略。下図(βがμ、γがσに対応)の方法でσとμをMain networkに入れる。style transferで使われるAdaINと似た手法である。σとμの直前のBatch Normalization layerでaffine変換をしないようにするのを忘れないように注意する。
YOTOの学習と推論手順
学習と推論の手順を説明しますが、オフィシャルな実装が公開されていないので、正確ではないところがあるかもしれません。
学習の準備
- ハイパーパラメータの値域とサンプリング方法を決める。[0,1]の範囲で一様分布からサンプリングなど。[1e-5,1]の範囲で対数一様分布からサンプリングなど。
- ハイパーパラメータの値域とサンプリング方法を決める。[0,1]の範囲で一様分布からサンプリングなど。[1e-5,1]の範囲で対数一様分布からサンプリングなど。
学習手順
- バッチサイズの数だけInputを生成。バッチサイズの数だけハイパーパラメータをサンプリング(バッチごとに1つのハイパーパラメータでもいいかも)。
- Inputとサンプリングしたハイパーパラメータをモデルに入力し、Lossを計算。モデルの重みを更新。
- 1~2を繰り返す。
- バッチサイズの数だけInputを生成。バッチサイズの数だけハイパーパラメータをサンプリング(バッチごとに1つのハイパーパラメータでもいいかも)。
推論手順
- Inputと任意のハイパーパラメータ値をモデルに入力し、出力を得る。
- Inputと任意のハイパーパラメータ値をモデルに入力し、出力を得る。
その他メモ
- 論文によると、モデルサイズが小さい場合、ハイパラ固定のモデルよりも性能が下がるらしい。モデルサイズを大きくしていくとその差が小さくなる。
不均衡データ対策の損失関数 Class Balanced Loss
後述する検証で使う不均衡データ対策の損失関数であるClass Balanced Lossについて説明します。
論文:[1901.05555] Class-Balanced Loss Based on Effective Number of Samples
日本語ブログ:不均衡データを損失関数で攻略してみる - Qiita
不均衡データをそのまま学習に使うと、データ数が多いクラス(Majorクラス)が重点的に学習され、データ数が少ないクラス(Minorクラス)に関する性能が悪化することが知られており、いろいろな対策が考えられています。対策の1つが損失関数の工夫で、データ数の逆数でLossの寄与率を重みづけ、Focal Lossなどがあります。
Class Balanced Lossは、データ数の逆数でLossの寄与率を重みづけの亜種です。
説明のために、クラス数がCである分類問題について考えます。
通常のSoftmax Cross Entropy Lossは次式のようになります。
ただし、iはバッチ内のサンプル番号、はサンプルiの正解ラベル、はラベルmに対応するソフトマックス出力です。
データ数の逆数でLossの寄与率を重みづけしたSoftmax Cross Entropy Lossは次式のようになります。
ただし、はラベルmのデータ数です。式からわかるようにデータ数の少ないクラスほどLossへの寄与率が大きくなるため、Minorクラスが重点的に学習されます。
Class balanced Softmax Cross Entropy Lossは次式のようになります。
詳しい意味は上述のブログや論文を参照してください。大事なポイントは以下のとおりです。
- データ数によってLossの寄与率が決まる。データ数が少ないほど寄与率は大きくなる。
- Loss(損失関数)のハイパーパラメータはβのみ。
- βの値域は[0,1]であり、β=0のときはCE、β=1のときはICEと一致する。その間の値のときはCEとICEの間のような感じ。すなわち、β=0のときは不均衡を考慮せずに学習し(Majorクラスを重点的に学習する)、β=1のときはMinorクラスを重点的に学習すると言えます。
ちょうどいいβを選択すると性能がよくなる。
Class balanced Lossではβのみがハイパーパラメータなので、後述する検証ではβを対象としてYOTOの学習をします。
検証
YOTOがどんなもんなのか感覚を掴むため、YOTOをClass balanced lossに適用してみました。ゴールは、YOTOで学習したモデルでClass balanced lossのβを動かし、MajorクラスとMinorクラスで分類性能が変わっていく様子を見ることです。Class balanced lossを選んだのは、Style transferとかより簡単そうだったからです。
目的
- YOTOを使って、1つのモデルだけでMajorクラスの性能が良いモデル or Minorクラスの性能が良いモデルをテスト時に選択できるようにする。
概要
- 不均衡データセットの画像分類タスクを対象とする。
- ResNet18とClass balanced lossを組み合わせたモデルを以下のケースで学習させ、Majorクラス、Minorクラス及び全クラスでのAccuracy等を比較する。
- YOTOでβを調整できるようにしたモデル
- βを固定して学習したモデル
- YOTOでβを調整できるようにしたモデル
条件
データセット
- Case 1
- Cifar 10でクラス毎のデータ数を不均衡にしたもの。trainデータではクラス0,2,4,6,8を各5000枚、クラス1,3,5,7,9を各250枚とした。testデータはデフォルト枚数(各クラス1000枚)のまま。
- Cifar 10でクラス毎のデータ数を不均衡にしたもの。trainデータではクラス0,2,4,6,8を各5000枚、クラス1,3,5,7,9を各250枚とした。testデータはデフォルト枚数(各クラス1000枚)のまま。
- Case 2
- Cifar 100でクラス毎のデータ数を不均衡にしたもの。trainデータではクラス0~49を各500枚、クラス50~99を各25枚とした。testデータはデフォルト枚数(各クラス100枚)のまま。
- Cifar 100でクラス毎のデータ数を不均衡にしたもの。trainデータではクラス0~49を各500枚、クラス50~99を各25枚とした。testデータはデフォルト枚数(各クラス100枚)のまま。
- Case 3
- Cifar 100でクラス毎のデータ数を不均衡にしたもの。trainデータではクラス0~49を各500枚、クラス50~99を各50枚とした。testデータはデフォルト枚数(各クラス100枚)のまま。
- Cifar 100でクラス毎のデータ数を不均衡にしたもの。trainデータではクラス0~49を各500枚、クラス50~99を各50枚とした。testデータはデフォルト枚数(各クラス100枚)のまま。
- Case 1
モデル
学習
- epoch数100、バッチサイズ128、SGD(momentum=0.9, weight decay=5e-4)、学習率0.1(epoch 50、85で0.1倍)
- YOTOでのβのサンプリング範囲は[0.9, 0.99999]で、1 - βの対数一様分布からサンプリング。
- epoch数100、バッチサイズ128、SGD(momentum=0.9, weight decay=5e-4)、学習率0.1(epoch 50、85で0.1倍)
実装
github参照
GitHub - statsu1990/yoto_class_balanced_loss: Unofficial implementation of YOTO (You Only Train Once) applied to Class balanced loss
結果
Case 1
Cifar 10で、trainデータのクラス0,2,4,6,8を各5000枚、クラス1,3,5,7,9を各250枚にしたデータセットのテスト結果は下図のとおりです。結果から以下のことがわかります。
- YOTOの結果より、1-βが大きい(βが大きく、Majorクラスが重点的に学習される)ときにはMajorクラスの分類精度が高く、1-βが小さい(βが小さく、Minorクラスが重点的に学習される)ときにはMinorクラスの分類精度が高くなった。
狙いどおり、テスト時にβを変えることで、1つのモデルでMajorクラスの性能が良いモデル or Minorクラスの性能が良いモデルを選択できるようになった。 - 1-βが大きいときYOTOは性能が低い。これは論文と一緒で、モデルが小さいことが原因と思われる。
- 1-βが小さいときYOTOの性能が高い。これは不思議で原因はよくわからない。1-βが大きいときと同じで、性能が下がると思っていた。1-βが大きいときはMajorクラスとMinorクラスの重みの比が20倍近くになるので、固定ハイパラだと学習が不安定になるが、YOTOだと安定するとかあるのかもしれない。
参考までに学習曲線をのせておきます。
Case 2
Cifar 100で、trainデータのクラス0~49を各500枚、クラス50~99を各25枚としたデータセットでのテスト結果は下図のとおりです。Case 1と同じ傾向となっています。
Case 3
Cifar 100で、trainデータのクラス0~49を各500枚、クラス50~99を各50枚としたデータセットでのテスト結果は下図のとおりです。Case 1, 2と同じ傾向となっています。ただし、βが小さいときの性能低下が大きいですね。
結果まとめ
- Case1~3のすべてで、YOTOを使うことで1つのモデルでMajorクラスの性能が良いモデル or Minorクラスの性能が良いモデルをテスト時に選択できた。
- Class balanced lossのβが大きいとき(Minorクラスの重み大)に固定ハイパラより性能がかなり良くなった。理由はわからないが、YOTOによって不安定な学習が安定したとかあるのかもしれない。
まとめ
- YOTOの概要を説明した。
- Class balaned lossの概要を説明した。
- 不均衡データでの画像分類タスクにおいて、YOTOをClass balanced lossを組み合わせることで、1つのモデルでMajorクラスの性能が良いモデル or Minorクラスの性能が良いモデルをテスト時に選択できた。
- 基本的に固定ハイパラよりYOTOのモデルは性能が落ちるが、性能が上がるケースもあった。原因はわからないが、不安定な学習を安定させられる効果がYOTOにはあるのかもしれない。
思ったこと
- 1つのモデルでアンサンブルできそう。
- ハイパーパラメータサーチに使えるかと思ってたけど、固定ハイパラと不一致がそれなりにあるので難しそう。Focal lossで試したけどうまくいかなかった。