iMind Developers Blog

iMind開発者ブログ

TFRecordの読み書き

概要

TensorFlowでTFRecordの読み書きをする。

バージョン情報

  • tensorflow==1.13.1

TFRecordとは

機械学習をする際に学習データがメモリに載せきれるようなサイズでない場合は、ストレージに置いておいて頻繁に読み込む必要がある。

TFRecordはその手の処理を行う際に適したファイルフォーマット。Protocol Buffersベース。

TFRecordファイルの出力

まずはサンプルデータの用意。

年齢(int)、性別(bool), 名前(bytes)を持つデータがあったとする。

import tensorflow as tf
import pandas as pd

df = pd.DataFrame([
    { 'age': 20, 'gender': True, 'name': b'higashi'},
    { 'age': 30, 'gender':  False, 'name': b'minami'},
    { 'age': 40, 'gender': True, 'name': b'nishi'}
])

df.head() 
    #=>    age  gender        name
    #=> 0   20    True  b'higashi'
    #=> 1   30   False   b'minami'
    #=> 2   40    True    b'nishi'

nameはbytesで持っておく。stringだと下記のようなエラーが出たりする。

TypeError: 'higashi' has type str, but expected one of: bytes

この情報をtf.train.Featureに変換する関数を書く。

def to_feature_map(age, gender, name):
    return {
        'age': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[age])),
        'gender': tf.train.Feature(
            int64_list=tf.train.Int64List(value=[gender])),
        'name': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[name]))
    }

print( to_feature_map(10, True, b'hoge') )
    #=> {'age': int64_list {
    #=>   value: 10
    #=> }
    #=> , 'gender': int64_list {
    #=>   value: 1
    #=> }
    #=> , 'name': bytes_list {
    #=>   value: "hoge"
    #=> }}

各レコードをtf.Exampleに変換して members.tfrecordというファイル名で保存してみる。

with tf.python_io.TFRecordWriter('members.tfrecord') as f:
    for idx, row in df.iterrows():
        feature = to_feature_map(
            row['age'], row['gender'], row['name'])
        features = tf.train.Features(feature=feature)
        example = tf.train.Example(features=features)
        f.write(example.SerializeToString())

members.tfrecord というファイルが出力される。

TFRecordファイルの読込み

Datasetを使って出力したファイルを読み込む。

事前に eager_execution を有効にしておく必要がある。

import tensorflow as tf
tf.enable_eager_execution()

ファイルを読込み、構造を指定してパースします。

# ファイルを読み込む
dataset = tf.data.TFRecordDataset('members.tfrecord')

# featureの構造を指定
feature_map = {
    'age': tf.FixedLenFeature([], tf.int64),
    'gender': tf.FixedLenFeature([], tf.int64),
    'name': tf.FixedLenFeature([], tf.string)
}


# パース
_parser = lambda row: tf.parse_single_example(row, feature_map)
parsed_dataset = dataset.map(_parser)

# 結果表示
for row in parsed_dataset:
    print(row)
    #=> {'age': <tf.Tensor: id=26, shape=(), dtype=int64, numpy=20>, 'gender': <tf.Tensor: id=27, shape=(), dtype=int64, numpy=1>, 'name': <tf.Tensor: id=28, shape=(), dtype=string, numpy=b'higashi'>}
    #=> {'age': <tf.Tensor: id=32, shape=(), dtype=int64, numpy=30>, 'gender': <tf.Tensor: id=33, shape=(), dtype=int64, numpy=0>, 'name': <tf.Tensor: id=34, shape=(), dtype=string, numpy=b'minami'>}
    #=> {'age': <tf.Tensor: id=38, shape=(), dtype=int64, numpy=40>, 'gender': <tf.Tensor: id=39, shape=(), dtype=int64, numpy=1>, 'name': <tf.Tensor: id=40, shape=(), dtype=string, numpy=b'nishi'>}

圧縮してTFRecordファイルを保存

TFRecordはZLIBかGZIP圧縮が用意されている。LZOやSnappyなどはいないらしい。

Tensorflowで重いデータを扱う場合は通常圧縮済みの画像ファイルなので、圧縮に対するモチベーションが高まる機会は少ないかもしれない。

下記はGZIP指定でWriterを呼び出す例。

writer = tf.python_io.TFRecordWriter('members.tfrecord',
    options=tf.python_io.TFRecordOptions(
        tf.python_io.TFRecordCompressionType.GZIP))

改定履歴

Author: Masato Watanabe, Date: 2019-07-12, 記事投稿