statsuのblog

愛知のデータサイエンティスト。自分の活動記録。主に機械学習やその周辺に技術について学んだことを記録していく予定。

Kaggleのくずし字認識(Kuzushiji recognition)コンペで15位になった

Kaggleのくずし字認識(Kuzushiji recognition)コンペで15位/293チームになりました。
https://www.kaggle.com/c/kuzushiji-recognition/leaderboard

コードをのっけておきます。
github.com

以下、記録です。

1. 記事の概要

  • くずし字認識(Kuzushiji recognition)コンペについて
  • 私の解法
  • 感想と反省

1. くずし字認識(Kuzushiji recognition)コンペについて

今回のコンペは以下のような内容でした。

  • 目的:文書画像から崩し字の位置を検出し、崩し字の種類を分類する。
  • 与えられるデータ:崩し字が書かれた文書画像(3881枚)、崩し字の位置を示すバウンディングボックス、字の種類。

Kuzushiji Recognition | Kaggle

崩し字認識の難しい点は以下のとおりです。

  • 画像内に文字が多数ある(0~614個)。
  • 文字の大きさがばらばら。小さいもの、大きいもの、細長いもの。
  • 文字の種類が多数(3422種)あり、出現回数の偏りが大きい(出現数1~24,685個)。まさに不均衡データ。

2. 私の解法

私の解法の概要は以下のとおりです。

  • Centernet(HourglassNetバックボーン)で文字を検出
  • Resnet baseなモデルで文字を分類
    f:id:st1990:20191024214148p:plain
    文字認識のフロー

※Centernetについては原論文と神Notebookを参照してください。
[1904.07850] Objects as Points CenterNet -Keypoint Detector- | Kaggle

最終的なプライベートリーダーボードのスコアは0.900でした。

Preprocessing

  • to gray scale
  • gaussian filter
  • gamma correction
  • ben's preprocessing

Detection

inference

2段階のcenternetを使って、以下の手順で文字のバウンディングボックスを検出します。

  • ステップ1:画像を512x512にリサイズして、centernet1でバウンディングボックス1を推定。
  • ステップ2:バウンディングボックス1を使って一番外側のバウンディングボックスの外側を除去。
  • ステップ3:画像を512x512にリサイズして、centernet2でバウンディングボックス2を推定。
  • ステップ4:バウンディングボックス1と2をアンサンブルし、最終的なバウンディングボックスを作成。

モデルアーキテクチャ

  • centernet1は2つのcenternet(1スタックのhourglassnetベース)のアンサンブル
  • centernet2は2つのcenternet(1スタックのhourglassnetベース)のアンサンブル

training

centernet1については以下のとおり。

  • 学習データ:全データの80%を使用。(データの分割を乱数で変えて2つのモデルを作成)
  • データ拡張:水平移動、輝度調整 データ拡張は過学習防止のために必須でした。

centernet2については以下のとおり。

  • 学習データ:全データの80%を使用。(データの分割を乱数で変えて2つのモデルを作成)
  • データ拡張:Random erasing、水平移動、輝度調整 バウンディングボックスの外側を除去した画像を入力とするため水平移動のデータ拡張の効果が薄かったため、Randome erasingが必須でした。

Classification

inference

3つのResnet baseのアンサンブルモデルを使って、以下の手順で文字ラベルを分類します。

  • ステップ1:推定したバウンディングボックスを使って元の画像から文字画像をクロップし、64x64にリサイズ。
  • ステップ2:3つのResnet baseのモデルで文字ラベルを分類する(テスト時augmentationは水平移動9種類)。
  • ステップ3:3つのモデルの分類結果をアンサンブルし、最終的な分類結果を推定。

モデルアーキテクチャ

  • Resnet base1:最終出力層の前にlog(バウンディングボックスのアスペクト比)を入力する。
  • Resnet base2:Resnet base1と学習データを変えたもの。
  • Resnet base3:上述のDetectionモデル、Resnet base1と2のアンサンブルモデルからpseudoラベリングした入力を学習データに加えたもの。構造はResnet base1と同じ。

training

上述のとおり学習データを変えていることとpseudoラベリングを使っていること以外は各モデルで同じ。

  • 学習データ:全データの80%を使用。
  • データ拡張:水平移動、回転、ズーム、Random erasing

3. 感想と反省

  • 物体検出を初めてやったのですが、画像認識より応用よりでなかなか面白かったです。
  • 自分なりにpipelineを組んで管理や組合せ変更をしやすいようにしてみました。使いやすかったし、作業時間を減らせたのでこれからも続けたい。
  • 自宅のGTX1080でDLの学習をすべて実施したのですが、学習時間と推論時間が長すぎて試行回数を増やせなかったです。GPU増設するか。。画像処理のCPUの部分もボトルネックにならないようにもっと気を付けて実装する必要あったなーと思います。
  • けっこう自前主義で、もろもろを自分で実装することが多かったのだが、公開されているものをもっと積極的に使うべきでした。自分の練習にはいいけど、本質にかける時間が足りなくなる。
  • pytorchの方が色々なモデルが転がっている気がする。kerasから乗り換えるか。。