statsuのblog

愛知のデータサイエンティスト。自分の活動記録。主に機械学習やその周辺に技術について学んだことを記録していく予定。

【論文メモ】Rethinking Normalization and Elimination Singularity in Neural Networks

「Rethinking Normalization and Elimination Singularity in Neural Networks」をざっと読んだのでそのメモです。
arxiv.org github.com

論文の概要

  • 画像認識に使うDeep neural network (DNN)の構造お話。
  • Batch normalization (BN)はDNNでとてもよく使われる正則化層であるが、バッチサイズが小さいときに性能が落ちる。この論文ではBNの代替としてBatch-Channel Normalization (BCN)を提案している。BCNはBNより性能がよく、小さいバッチサイズでも使える。
  • 検証ではBCN+Weight Standalization (WS)※がBNやGroup normalization (GN)+WSよりも良い精度を出していた。
  • BCNの導出前の考察として、なぜGNやLayer Normalization (LN)がBNに劣るか、なぜWSが効くかについて、Reluによるsingularity発生という観点から述べられている。

※WS : 最近話題のBiT-Lでも使われているホットな奴です。
[1903.10520] Weight Standardization

Batch-Channel Normalization

BCNはバッチサイズが大きいときと小さいときで処理が異なる。

バッチサイズが大きいとき

入力をXとすると、BCNの出力BCN(X)は次式で表される。
BCN(X)=GN(BN(X))
要するに、BNしてからGNしているだけ。簡単。
一見冗長であるが非線形性が増したりするので意味があると論文では言及されてる。

バッチサイズが小さいとき

入力をXとすると、BCNの出力BCN(X)は次式で表される。
BCN(X)=GN(BN'(X))
さっきと同じように見えるが、このBN'は普通のBNの以下の点が異なる。

  • 普通のBN

    • 学習時:バッチ内でのXの平均、標準偏差を使って計算したμ及びσを使う。推論時に使う用のμ、σを、バッチ内Xの平均、標準偏差の指数移動平均で計算する。
  • BN'

    • 学習時:μ、σを、バッチ内Xの平均、標準偏差の指数移動平均で計算する。このμ、σを使ってBN'(X)を計算する。普通のBNでは推論時に使ってる"移動平均で計算されるμ、σ"を学習時にも使っているイメージです。アルゴリズムは以下のとおりです。

f:id:st1990:20200122001936p:plain
BCNのアルゴリズム

検証結果

画像認識ではCifar10, Cifar100, ImageNet、物体検出ではCOCO、セグメンテーションではPASCALで検証されています。
BCN+WSが強いですね。物体検出とセグメンテーションではBNと比較していないのが残念。

f:id:st1990:20200122002649p:plain
検証1
f:id:st1990:20200122002717p:plain
検証2

感想

  • BCNではBNとGNの両方を使っています。単純に考えると冗長としか思えないのに、それが効くって面白いですね。論文ではなぜ効くかをきちんと考察しています。ブラックボックスにせずに、各層がどんな役割を持っているかのイメージをきちんと持つことが大事ですね。
  • バッチサイズが小さいときのBCNのμ、σはforward時に更新されるんですね。これだとgradient accumulateするときに支障になる?僕のような雑魚GPUユーザー的には、gradient accumulateで精度が落ちない手法がうれしい。その点でGN+WSは良いね。