概要
TensorFlowでTFRecordの読み書きをする。
バージョン情報
- tensorflow==1.13.1
TFRecordとは
機械学習をする際に学習データがメモリに載せきれるようなサイズでない場合は、ストレージに置いておいて頻繁に読み込む必要がある。
TFRecordはその手の処理を行う際に適したファイルフォーマット。Protocol Buffersベース。
TFRecordファイルの出力
まずはサンプルデータの用意。
年齢(int)、性別(bool), 名前(bytes)を持つデータがあったとする。
import tensorflow as tf import pandas as pd df = pd.DataFrame([ { 'age': 20, 'gender': True, 'name': b'higashi'}, { 'age': 30, 'gender': False, 'name': b'minami'}, { 'age': 40, 'gender': True, 'name': b'nishi'} ]) df.head() #=> age gender name #=> 0 20 True b'higashi' #=> 1 30 False b'minami' #=> 2 40 True b'nishi'
nameはbytesで持っておく。stringだと下記のようなエラーが出たりする。
TypeError: 'higashi' has type str, but expected one of: bytes
この情報をtf.train.Featureに変換する関数を書く。
def to_feature_map(age, gender, name): return { 'age': tf.train.Feature( int64_list=tf.train.Int64List(value=[age])), 'gender': tf.train.Feature( int64_list=tf.train.Int64List(value=[gender])), 'name': tf.train.Feature( bytes_list=tf.train.BytesList(value=[name])) } print( to_feature_map(10, True, b'hoge') ) #=> {'age': int64_list { #=> value: 10 #=> } #=> , 'gender': int64_list { #=> value: 1 #=> } #=> , 'name': bytes_list { #=> value: "hoge" #=> }}
各レコードをtf.Exampleに変換して members.tfrecordというファイル名で保存してみる。
with tf.python_io.TFRecordWriter('members.tfrecord') as f: for idx, row in df.iterrows(): feature = to_feature_map( row['age'], row['gender'], row['name']) features = tf.train.Features(feature=feature) example = tf.train.Example(features=features) f.write(example.SerializeToString())
members.tfrecord というファイルが出力される。
TFRecordファイルの読込み
Datasetを使って出力したファイルを読み込む。
事前に eager_execution を有効にしておく必要がある。
import tensorflow as tf tf.enable_eager_execution()
ファイルを読込み、構造を指定してパースします。
# ファイルを読み込む dataset = tf.data.TFRecordDataset('members.tfrecord') # featureの構造を指定 feature_map = { 'age': tf.FixedLenFeature([], tf.int64), 'gender': tf.FixedLenFeature([], tf.int64), 'name': tf.FixedLenFeature([], tf.string) } # パース _parser = lambda row: tf.parse_single_example(row, feature_map) parsed_dataset = dataset.map(_parser) # 結果表示 for row in parsed_dataset: print(row) #=> {'age': <tf.Tensor: id=26, shape=(), dtype=int64, numpy=20>, 'gender': <tf.Tensor: id=27, shape=(), dtype=int64, numpy=1>, 'name': <tf.Tensor: id=28, shape=(), dtype=string, numpy=b'higashi'>} #=> {'age': <tf.Tensor: id=32, shape=(), dtype=int64, numpy=30>, 'gender': <tf.Tensor: id=33, shape=(), dtype=int64, numpy=0>, 'name': <tf.Tensor: id=34, shape=(), dtype=string, numpy=b'minami'>} #=> {'age': <tf.Tensor: id=38, shape=(), dtype=int64, numpy=40>, 'gender': <tf.Tensor: id=39, shape=(), dtype=int64, numpy=1>, 'name': <tf.Tensor: id=40, shape=(), dtype=string, numpy=b'nishi'>}
圧縮してTFRecordファイルを保存
TFRecordはZLIBかGZIP圧縮が用意されている。LZOやSnappyなどはいないらしい。
Tensorflowで重いデータを扱う場合は通常圧縮済みの画像ファイルなので、圧縮に対するモチベーションが高まる機会は少ないかもしれない。
下記はGZIP指定でWriterを呼び出す例。
writer = tf.python_io.TFRecordWriter('members.tfrecord',
options=tf.python_io.TFRecordOptions(
tf.python_io.TFRecordCompressionType.GZIP))
改定履歴
Author: Masato Watanabe, Date: 2019-07-12, 記事投稿