iMind Developers Blog

iMind開発者ブログ

Open Image Dataset v5のデータを使って物体検知

概要

Open Image Dataset v5(以下OID)のデータを使って、SSDでObject Detectionする。

全クラスを学習するのは弊社の持っているリソースでは現実的ではない為、リンゴ、オレンジ、苺、バナナの4クラスだけで判定するモデルを作ってみる。

バージョン情報

  • Python 3.7.3
  • Keras==2.2.4
  • tensorflow-gpu==1.13.1
  • tensorflow/models v1.13.0

データ量

OIDのApple, Orange, Strawberry, BananaのIDは下記。

/m/014j1m,Apple
/m/0cyhj_,Orange
/m/07fbm7,Strawberry
/m/09qck,Banana

このIDに該当するラベルと画像だけを抽出して学習する。

各ラベルに対する写真の枚数。

label train validation test
Apple 3,898 128 407
Orange 6,195 176 916
Strawberry 7,944 330 816
Banana 1,612 26 167
19,649 660 2,306

学習データ全体では19,649。このくらいであればGPUがついた個人のPCでも扱える。

利用するライブラリ

tensorflowのmodelsに置いてあるobject_detectionを利用する。

https://github.com/tensorflow/models/tree/master/research/object_detection

正式なライブラリという立ち位置ではないけどApacheライセンスでメンテもちょくちょく行われているので、メンテが行われていない古い実装を使うより引っかかる点が少なく済む。

OID V4を利用した例も書いてある。SSDだけでなくFaster R-CNNやMask-RCNN、Android向けの実装等も用意されている。

一部処理はPython2向けになっているようで3では動かない箇所がある。適宜修正して動かした。

導入

必要になりそうなライブラリをわさわさ入れる。下記は一例。TensorflowはCPU版を指定しているので、必要があればGPU版に書き換える。

name: fruits-detection

channels:
  - defaults
  - anaconda
  - menpo

dependencies:
  - python==3.7
  - numpy==1.16.4
  - pandas==0.24.2
  - scikit-learn==0.20.3
  - tensorflow==1.13.1
  - Keras==2.2.4
  - Pillow==6.0.0
  - imageio==2.5.0
  - lxml==4.3.3
  - matplotlib
  - pip:
    - mlflow==1.0.0
    - python-dateutil==2.8.0
    - pytz==2019.1
    - Click==7.0
    - retrying==1.3.3
    - scipy==1.3.0
    - opencv-python==4.1.0.25
    - Cython==0.29.11
    - contextlib2==0.5.5
    - jupyter
    - tqdm

condaで上記yamlのパスを指定してenvを生成。

$ conda env create -f {yamlのパス}
$ conda activate fruits-detection

続いて下記のInstallationの項を実行してセットアップを完了させる。

https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/installation.md

datasetの用意

OIDのデータ取得については下記を参照。

https://blog.imind.jp/entry/2019/06/18/210510

落としてきたデータを学習や推測に用いる際はTFRecord file formatに直す必要がある。

TensorFlow Modelsのresearch配下の object_detection/dataset_tools/create_oid_tf_record.py がTFRecordファイルの生成処理。

ファイルのパス構成は下記のようになることがレコメンドされているので、これに合わせてディレクトリを作成する。label_map_fileとpipeline_config_fileはファイルでそれ以外はディレクトリ。

├── data
│   ├── eval
│   ├── label_map_file
│   ├── test
│   └── train
└── models
    └── fruits_model
        ├── eval
        ├── pipeline_config_file
        ├── test
        └── train

この構成通りのディレクトリ構成にすると、学習時にValidationのデータ等を読み込んでくれる。

次にOIDのtrain, test, validationそれぞれのannotations-bbox.csvから必要なラベルだけ抜き出す。

for type in train test validation
do
    head -1 ${type}-annotations-bbox.csv > fruits-${type}-annotations-bbox.csv
    grep "/m/014j1m" ${type}-annotations-bbox.csv >> fruits-${type}-annotations-bbox.csv
    grep "/m/0cyhj_" ${type}-annotations-bbox.csv >> fruits-${type}-annotations-bbox.csv
    grep "/m/07fbm7" ${type}-annotations-bbox.csv >> fruits-${type}-annotations-bbox.csv
    grep "/m/09qck" ${type}-annotations-bbox.csv >> fruits-${type}-annotations-bbox.csv
done

これで3つのファイルができる。ファイルの行数はそれぞれのイメージ数+1(ヘッダの分)になる。

$ wc -l fruits-*

   2307 fruits-test-annotations-bbox.csv
  19650 fruits-train-annotations-bbox.csv
    661 fruits-validation-annotations-bbox.csv

各ファイルをcloneした modelsプロジェクトの research/data 配下に置いておく。

├── data
│   ├── eval
│   ├── fruits-test-annotations-bbox.csv
│   ├── fruits-train-annotations-bbox.csv
│   ├── fruits-validation-annotations-bbox.csv
│   ├── label_map_file
│   ├── test
│   └── train
└── models
    └── fruits_model
        ├── eval
        ├── pipeline_config_file
        ├── test
        └── train

画像ファイルについても必要なファイルだけコピーして1つのディレクトリ配下に入れておく。

import os
import shutil
import pandas as pd

data_types = {
    'train': 'train', 'validation': 'eval', 'test': 'test'
}

for src_type, dest_type in data_types.items():
    df = pd.read_csv('data/fruits-%s-annotations-bbox.csv' % src_type)

    src_dir = '{OpenImageDatasetV5の画像を置いたパス/%s' % src_type
    dest_dir = 'data/%s' % dest_type

    for idx, row in df.iterrows():
        image_name = '{}.jpg'.format(row.ImageID)
        src_path = os.path.join(src_dir, image_name)
        dest_path = os.path.join(dest_dir, image_name)
        shutil.copyfile(src_path, dest_path)

label_mapを記述する。書き方は下記ファイルが参考になる。

  • object_detection/data/oid_v4_label_map.pbtxt

ファイル名は仮にoid_v5_fruits_label_map.pbtxtとしておく。

item {
  name: "/m/014j1m"
  id: 1
  display_name: "Apple"
}
item {
  name: "/m/0cyhj_"
  id: 2
  display_name: "Orange"
}
item {
  name: "/m/07fbm7"
  id: 3
  display_name: "Strawberry"
}
item {
  name: "/m/09qck"
  id: 4
  display_name: "Banana"
}

生成したlabel_map_fileはdataフォルダ配下に置く。

├── data
│   ├── eval
│   ├── fruits-test-annotations-bbox.csv
│   ├── fruits-train-annotations-bbox.csv
│   ├── fruits-validation-annotations-bbox.csv
│   ├── oid_v5_fruits_label_map.pbtxt
以下略

TFRecordファイルに変換

annotations-bbox.csv, label_map, 画像ファイルが用意できたら、create_oid_tf_record.py(Open Image Dataset用のTFRecord変換機能)で学習用のデータをTFRecordファイルに変換する。

train, test, validationそれぞれのTFRecordファイルを生成。

# trainは数が多いのでnum_shards=10くらいで
mkdir -p models/fruits_model/train
python object_detection/dataset_tools/create_oid_tf_record.py \
    --input_box_annotations_csv data/fruits-train-annotations-bbox.csv \
    --input_images_directory data/train \
    --input_label_map data/oid_v5_fruits_label_map.pbtxt \
    --output_tf_record_path_prefix models/fruits_model/train/oid_fruits \
    --num_shards 10

# validationはnum_shards=1で
mkdir -p models/fruits_model/eval
python object_detection/dataset_tools/create_oid_tf_record.py \
    --input_box_annotations_csv data/fruits-validation-annotations-bbox.csv \
    --input_images_directory data/eval \
    --input_label_map data/oid_v5_fruits_label_map.pbtxt \
    --output_tf_record_path_prefix models/fruits_model/eval/oid_fruits \
    --num_shards 1

# testはnum_shards=1で
mkdir -p models/fruits_model/test
python object_detection/dataset_tools/create_oid_tf_record.py \
    --input_box_annotations_csv data/fruits-test-annotations-bbox.csv \
    --input_images_directory data/test \
    --input_label_map data/oid_v5_fruits_label_map.pbtxt \
    --output_tf_record_path_prefix models/fruits_model/test/oid_fruits \
    --num_shards 1

これでmodelsディレクトリ配下にnum_shardsに指定した数で分割されたTFRecordファイルが生成される。(エラーが出る場合があるけど、それについては後述)

$ ls models/fruits_model/train

oid_fruits-00000-of-00010  oid_fruits-00002-of-00010  oid_fruits-00004-of-00010  oid_fruits-00006-of-00010  oid_fruits-00008-of-00010
oid_fruits-00001-of-00010  oid_fruits-00003-of-00010  oid_fruits-00005-of-00010  oid_fruits-00007-of-00010  oid_fruits-00009-of-00010

引数説明。

num-shardsはファイルを分割する数。デフォルトでは100になっているが、今回のデータはサンプリングしてあるのでtestデータだと5くらいで十分。100MB以下くらいにすると良いらしい。

output_tf_record_path_prefix で出力されるファイルのprefixを指定する。

今回は models/(train|eval|test)/oid_fruits を指定しているので、出力されるファイル名は models/(train|eval|test)/oid_fruits-{number}-of-{total_number} になる。

create_oid_tf_record.pyでエラーが出る場合

うちの環境ではこれを実行すると encoded_image = image_file.read() のところで下記のようなエラーが出た。

UnicodeDecodeError: 'utf-8' codec can't decode byte 0xff in position 0: invalid start byte

いくつか修正する。

# object_detection/dataset_tools/create_oid_tf_record.py : 106行目

      with tf.gfile.Open(image_path) as image_file:

          ↓

      with tf.gfile.Open(image_path, 'rb') as image_file:
# object_detection/utils/dataset_util.py : 29行目

def bytes_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))

        ↓

def bytes_feature(value):
  if isinstance(value, str):
    value = value.encode()
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
# object_detection/utils/dataset_util.py : 33行目

def bytes_list_feature(value):
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

        ↓

def bytes_list_feature(value):
  value = [v.encode() if isinstance(v, str) else v for v in value]
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=value))

train

続いてSSDの学習用のconfigを作る。

samplesからやりたい処理に似たconfigをコピーする。

$ cp object_detection/samples/configs/ssd_inception_v3_pets.config  models/fruits_model/ssd_inception_v3_fruits.config

コピーしたファイルを今回のデータに合わせて編集する。

# classの数はorange, apple, banana, strawberryの4つ
num_classes: 4
# 再開用の設定であるfine_tune_checkpointはコメントアウト
# fine_tune_checkpoint: "my_models/fruits_model.ckpt"
# 入力のTFRecordとlabel_mapのpathを指定する
train_input_reader: {
  tf_record_input_reader {
    input_path: "models/fruits_model/train/oid_fruits-*"
  }
  label_map_path: "data/oid_v5_fruits_label_map.pbtxt"
}
# oid_challenge_detection_metricsを使う
# 枚数はvalidationに入っている660枚。
eval_config: {
  metrics_set: "oid_challenge_detection_metrics"
  num_examples: 660
}
# evaluate用のファイルの設定
eval_input_reader: {
  tf_record_input_reader {
    input_path: "models/fruits_model/eval/oid_fruits-*"
  }
  label_map_path: "data/oid_v5_fruits_label_map.pbtxt"
  shuffle: false
  num_readers: 1
}

object_detection/utils/object_detection_evaluation.py で unicode関数が呼ばれていてPython3ではエラーになる。

当該ファイルの上の方で下記を宣言して回避する。

unicode = str

これで準備完了。

trainを実行する。初回はさらっと終わるようにNUM_TRAIN_STEPS=3000と少なめの数を指定しておく。

実行時間はGTX1070tiで30分位。

PIPELINE_CONFIG_PATH=models/fruits_model/ssd_inception_v3_fruits.config
MODEL_DIR=models/fruits_model
NUM_TRAIN_STEPS=3000
SAMPLE_1_OF_N_EVAL_EXAMPLES=1
python object_detection/model_main.py \
    --pipeline_config_path=${PIPELINE_CONFIG_PATH} \
    --model_dir=${MODEL_DIR} \
    --num_train_steps=${NUM_TRAIN_STEPS} \
    --sample_1_of_n_eval_examples=$SAMPLE_1_OF_N_EVAL_EXAMPLES \
    --alsologtostderr

実行状況は下記で確認できる。

tensorboard --logdir=models/fruits_model

各ラベルごとのAPの推移等もちゃんと用意されている。3000回くらいだとまだほとんど当たらない。

f:id:mwsoft:20190720203914p:plain

上部タブのIMAGESを選択する実際にvalidationデータの判定結果が見れる。

f:id:mwsoft:20190720204034p:plain

eval_input_reader の shuffle を Trueにしておくとこの画面に表示する画像を毎回shuffleして出してくれる。

checkpointを読み込んでtrainする

models配下に生成されたcheckpointから学習を再開してみる。

先ほどの実行でckpt-3000というファイルが出来ているはず。

$ ls models/fruits_model/model.ckpt-3000*

models/fruits_model/model.ckpt-3000.data-00000-of-00001
models/fruits_model/model.ckpt-3000.meta
models/fruits_model/model.ckpt-3000.index

models/fruits_model/ssd_inception_v3_fruits.config を編集して fine_tune_checkpoint に再開に利用するチェックポイントを指定する。

  fine_tune_checkpoint: "models/fruits_model/model.ckpt-3000"

NUM_TRAIN_STEPSの数を上げてtrainを実行する。

PIPELINE_CONFIG_PATH=models/fruits_model/ssd_inception_v3_fruits.config
MODEL_DIR=models/fruits_model
NUM_TRAIN_STEPS=10000
SAMPLE_1_OF_N_EVAL_EXAMPLES=1
python object_detection/model_main.py \
    --pipeline_config_path=${PIPELINE_CONFIG_PATH} \
    --model_dir=${MODEL_DIR} \
    --num_train_steps=${NUM_TRAIN_STEPS} \
    --sample_1_of_n_eval_examples=$SAMPLE_1_OF_N_EVAL_EXAMPLES \
    --alsologtostderr

学習したモデルを使ってobject detectionの実行

学習が終わったらテストデータに対してobject detectionを実行してみる。

最後のcheckpointからfrozenなモデルの生成。

python object_detection/export_inference_graph.py \
    --input_type image_tensor \
    --pipeline_config_path models/fruits_model/ssd_inception_v3_fruits.config \
    --trained_checkpoint_prefix models/fruits_model/model.ckpt-100000 \
    --output_directory fine_tuned_model

これで --output_directory に指定したパスにモデルが出力される。

TF_RECORD_FILESにはカンマ区切りでファイルが指定できる。下記では1ファイルのみ指定。

mkdir result

PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python
TF_RECORD_FILES=models/fruits_model/test/oid_fruits-00000-of-00001
INFER_GRAPH=fine_tuned_model/frozen_inference_graph.pb
python object_detection/inference/infer_detections.py \
  --input_tfrecord_paths ${TF_RECORD_FILES} \
  --output_tfrecord_path result/fruits_test.tfrecord \
  --inference_graph ${INFER_GRAPH} \
  --discard_image_pixels

--output_tfrecord_path に指定したパスに結果が出力される。

出力されたtf_recordファイルを読み込む。

ファイルのフォーマットは core/standard_fields.py の DetectionResultFields クラスを参照。

num_detections, detection_boxes, detection_boxesを見ておけば良さそう。

import sys
import os

import tensorflow as tf
from object_detection.data_decoders import tf_example_decoder

tf.enable_eager_execution()

# tfrecordのformatはTfExampleDecoderから取得する
decoder = tf_example_decoder.TfExampleDecoder()
features = decoder.keys_to_features

# ファイルを読み込む
dataset = tf.data.TFRecordDataset('fruits_test.tfrecord')

# パース
_parser = lambda row: tf.parse_single_example(row, features)
parsed_dataset = dataset.map(_parser)

# 結果表示
for row in parsed_dataset:
    pass

print(row)
    #=> {'image/class/label': <tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7fc9a82ff208>,
    #=>  'image/class/text': <tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7fc9a82ff2e8>,
    #=>  'image/object/area': <tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7fc9a8c01978>,
    #=>  'image/object/bbox/xmax': <tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7fc9a8c23320>,
    #=>  'image/object/bbox/xmin': <tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7fc9a8c23e80>,
    #=>  'image/object/bbox/ymax': <tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7fc9a8c23550>,
    #=>  'image/object/bbox/ymin': <tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7fc9a8c23048>,
    #=>  'image/object/class/label': <tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7fc9a8c23cc0>,
    #=>  'image/object/class/text': <tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7fc9a8c23f28>,
    #=>  'image/object/difficult': <tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7fc9a8c23710>,
    #=>  'image/object/group_of': <tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7fc9a8c23518>,
    #=>  'image/object/is_crowd': <tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7fc9f37adda0>,
    #=>  'image/object/weight': <tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7fc9a8ba8748>,
    #=>  'image/encoded': <tf.Tensor: id=17224, shape=(), dtype=string, numpy=b''>,
    #=>  'image/filename': <tf.Tensor: id=17225, shape=(), dtype=string, numpy=b'feb546bb21cec831.jpg'>,
    #=>  'image/format': <tf.Tensor: id=17226, shape=(), dtype=string, numpy=b'jpeg'>,
    #=>  'image/height': <tf.Tensor: id=17227, shape=(), dtype=int64, numpy=1>,
    #=>  'image/key/sha256': <tf.Tensor: id=17228, shape=(), dtype=string, numpy=b''>,
    #=>  'image/source_id': <tf.Tensor: id=17240, shape=(), dtype=string, numpy=b'feb546bb21cec831'>,
    #=>  'image/width': <tf.Tensor: id=17241, shape=(), dtype=int64, numpy=1>}

テストデータの評価

下記ページの手順に従って評価を実行する。

https://github.com/tensorflow/models/blob/master/research/object_detection/g3doc/oid_inference_and_evaluation.md

2箇所ほどエラーが出たのでコードを修正している。

# object_detection/metrics/offline_eval_map_corloc.py : 162行目

input_config = configs['eval_input_config']

  ↓

input_config = configs['eval_input_configs'][0]
# object_detection/metrics/tf_example_parser.py : 46行目

  def parse(self, tf_example):
    return "".join(str(tf_example.features.feature[self.field_name]
                   .bytes_list.value)) if tf_example.features.feature[
                       self.field_name].HasField("bytes_list") else Nonedef parse(self, tf_example):
    return "".join(tf_example.features.feature[self.field_name]
                   .bytes_list.value) if tf_example.features.feature[
                       self.field_name].HasField("bytes_list") else None

必要なconfigを出力して offline_eval_map_corloc.py を実行する。

mkdir tmp

echo "
label_map_path: 'data/oid_v5_fruits_label_map.pbtxt'
tf_record_input_reader: { input_path: 'result/fruits_test.tfrecord' }
" > tmp/input_config.pbtxt

echo "
metrics_set: 'oid_V2_detection_metrics'
" > tmp/eval_config.pbtxt

python object_detection/metrics/offline_eval_map_corloc.py \
  --eval_dir=tmp/eval \
  --eval_config_path=tmp/eval_config.pbtxt \
  --input_config_path=tmp/input_config.pbtxt

実行結果は --eval_dir で指定したパスに出力される。

$ cat tmp/eval/metrics.csv

OpenImagesV2_PerformanceByCategory/AP@0.5IOU/b'Apple',0.47982193735198647
OpenImagesV2_PerformanceByCategory/AP@0.5IOU/b'Orange',0.1849362140783137
OpenImagesV2_PerformanceByCategory/AP@0.5IOU/b'Strawberry',0.2862974195468154
OpenImagesV2_PerformanceByCategory/AP@0.5IOU/b'Banana',0.3910937216356761

上記はNUM_TRAIN_STEPSを20万まで増やして実行したモデルだけど、だいぶ低い。

一番良いリンゴで約48%、オレンジは18%。

テーブルの上に1つフルーツが置いてあるようなシチュエーションであればこのモデルでも十分使えるけど、果物の周囲がごちゃっとしてる画像だとかなり判定をミスる。

単一画像の処理

画像を読み込み、当該モデルを利用して物体検知結果を出力する方法は下記に記載されている。

https://github.com/tensorflow/models/blob/master/research/object_detection/object_detection_tutorial.ipynb

これを参考に、単純に1画像のみを読み込んで実行するサンプルコードを記載した。下記参照。

https://github.com/imind-inc/blog/blob/master/notebook/tf_models_object_detection/simple_image_detection.ipynb

PATH_TO_FROZEN_GRAPH を今回生成したモデルのパスに変更すれば実行できるはず。

精度を上げる

デフォルトのまま回しただけではあまり精度が出なかった。

confの項目の中から下記などをチューニングしていくと5%くらいはさらっと精度が向上した。

  • data augmentation
  • 利用する画像の選別
  • learning_rateを調整する

頑張ればもう少し上げられそう。

data augmentation

data augmentationの設定例。

pipe line configファイル(上の例では models/fruits_model/ssd_inception_v3_fruits.config)を見てみると、下記のような設定がされている。

  data_augmentation_options {
    random_horizontal_flip {
    }
  }
  data_augmentation_options {
    ssd_random_crop {
    }
  }

data_augmentation_optionsに指定できるパラメータは、 protos/preprocessor_pb2.py に載っている。

下記は回転や明度彩度などの調整を加えた設定。

  data_augmentation_options {
    random_horizontal_flip {
    }
  }
  data_augmentation_options {
    random_vertical_flip {
    }
  }
  data_augmentation_options {
    random_rotation90 {
    }
  }
  data_augmentation_options {
    ssd_random_crop {
    }
  }
  data_augmentation_options {
    random_rgb_to_gray {
      probability=0.1
    }
  }
  data_augmentation_options {
    random_adjust_brightness {
      max_delta=0.2
    }
  }
  data_augmentation_options {
    random_adjust_contrast {
      min_delta=0.8,
      max_delta=1.25
    }
  }
  data_augmentation_options {
    random_adjust_hue {
    }
  }
  data_augmentation_options {
    random_adjust_saturation {
    }
  }

各種data augmentationの動きについては下記を参照。

https://blog.imind.jp/entry/2019/07/20/132143

改定履歴

Author: Masato Watanabe, Date: 2019-07-29, 記事投稿