statsuのblog

愛知のデータサイエンティスト。自分の活動記録。主に機械学習やその周辺に技術について学んだことを記録していく予定。

ReZeroの収束性と精度について画像認識(Cifar100)で検証した記録

ReZeroという最近提案されたDeep learning関連の手法を画像認識(Cifar100)で試したのでその記録です。
結論としては、Cifar100での画像認識では効果なかったです。(なんかミスしている可能性もなくはない)

本記事の概要

  • ReZeroの概要
  • ReZeroの実装
  • ReZeroの検証
  • まとめ

ReZeroの概要

ReZeroは"ReZero is All You Need: Fast Convergence at Large Depth"という論文で提案された手法です。(〇〇 is All You Needって言ってみたい…)
[2003.04887] ReZero is All You Need: Fast Convergence at Large Depth
詳細は論文を読んでもらうとして、ここでは論文の主張について概要を簡単に説明します。

ReZeroとは

Deep learningでは、深い層数で学習するために正規化(BatchNorm、LayerNormなど)やResidual connection(ResNetなどで使われるやつ)が使われます。これらを使わないと層数が多い場合には勾配消失等の問題によりうまく学習できません。

ReZeroは、従来の正規化やResidual connectionの代わりとして使え、収束性が上がるらしいです(ただし、BatchNormに関してはReZeroのアプローチを補完するらしい。よくわからん。)。

ReZeroの構造は簡単で、Residual connectionにResidual weight αという学習可能パラメータを追加した形となっています。αの初期値は0にします。これを以下のように正規化やResidual connectionの代わりに使います。
f:id:st1990:20200329025620p:plain

f:id:st1990:20200329025601p:plain

f:id:st1990:20200329025644p:plain

効果

論文によると、以下の効果があるそうです。

  • より深い層数のアーキテクチャの学習が可能となる。10000層の全結合ネットワークや100層以上のTransformerを学習できた。
  • 収束性が早くなる。Transformerでは、enwiki8ベンチマークで1.2BPBに56%速く到達した。ResNetでは、Cifar 10で85%の精度に32%速く到達した。ただし、Cifar 10での検証結果は、なぜかおまけみたいな扱いで少ししか触れられていません。

実装

私はpytorchで以下のように実装しました。簡単で良い。

import torch
import torch.nn as nn
from torch.nn.parameter import Parameter

class ReZeroShortcut(nn.Module):
    def __init__(self, alpha=0.0):
        super(ReZeroShortcut, self).__init__()
        self.alpha = Parameter(torch.ones(1) * alpha)

    def forward(self, shortcut, x):
        return shortcut + self.alpha * x

GitHub - statsu1990/ReZero-Cifar100: Verification of ReZero ResNet on cifar 100 dataset

検証

Cifar10

論文ではCifar10で以下の結果が出ています。

  • ResNet56のresidual connectionをすべてReZero connectionに置き換える。
  • validation errorが7.37±0.06 %から6.46±0.05%に向上した。
  • validationエラーが15%未満になるepoch数が32±14%減少した。

非公式ですがCifar10での実験結果をgithubで見つけました。比較の仕方がよくわからないところもあるけど、ReZeroの効果があるようには見えない。。
GitHub - fabio-deep/ReZero-ResNet: Unofficial pytorch implementation of ReZero in ResNet

Cifar100

さて、私はCifar100でPreAct-ResNetを学習させ、ReZeroの有無で精度と収束性がどう変わるか検証しました。検証に使ったコードはこちらです。
GitHub - statsu1990/ReZero-Cifar100: Verification of ReZero ResNet on cifar 100 dataset

目的

  • ReZeroの有無による精度と収束性の違いを確認する。

検証方法・条件

以下の条件で学習させ、ReZero有無での精度と収束性を比較します。

  • データ
    • cifar 100
  • モデル
    • ベースモデル:PreAct ResNet 18, 50
      [1603.05027] Identity Mappings in Deep Residual Networks
    • モデル with ReZero:ベースモデル中のresidual connectionをすべてReZero connectionに変えたもの。residual weight αの初期値は0。
    • モデル with ReZero (個人的改良版):ベースモデル中のresidual connectionをすべてReZero connectionに変えたもの。αの代わりにtanh(α)を使用(αの初期値は0)。αが異常に大きくなることを避けるため、値域を限定する目的でtanh(α)としてみた。
  • 学習方法
    • Cross entropy loss
    • SGD, 学習率0.1 (60、120、160epochで学習率を0.2倍し小さくする)、epoch数200、バッチサイズ128
    • Data augmentation (random flip, random shift scale rorate)

検証結果

(2020/3/30修正:rezero preact-resnet18の結果を少し修正)

学習曲線がこちらです。精度も収束性も改善していません。

PreAct ResNet 18

f:id:st1990:20200330003259j:plain
loss_preact-resnet18
f:id:st1990:20200330003318j:plain
accuracy_preact-resnet18

PreAct ResNet 50

f:id:st1990:20200329141830j:plain
loss_preact-resnet50
f:id:st1990:20200329142403j:plain
accuracy_preact-resnet50

学習完了後のモデル with ReZeroのResidual weight αの値は以下のとおりです。stageは下表のlayer nameのconv2_x~conv5_xに、blockは3×3,64や1×1,64に対応しています。
概ね-2~2程度の値をとるようです。たまに絶対値が大きい値がありますね。

f:id:st1990:20200330003357j:plain
alpha_preact-resnet18
f:id:st1990:20200329142729j:plain
alpha_preact-resnet50
f:id:st1990:20200329143352p:plain
https://arxiv.org/pdf/1512.03385.pdf

まとめ

  • 深い層でも学習できるようになる&学習の収束性を上げることができるというReZeroの概要を説明しました。
  • ReZeroはCifar100 + Preact-ResNetでは効果がなかったです。なんなら逆に収束性も精度も悪くなりました。なんかやり方間違っているのかな。誰か間違いに気づいたら教えてください。
  • Deep Learningの論文の再現性って闇深い。ちゃんと検証してみたら効果なかったわ、みたいな論文が定期的に出ている気がする。