def _create_tfrecord_dataset(tmpdir):
    """Build a TFRecord-backed slim ``Dataset`` over a single generated file.

    The directory ``tmpdir`` is created if it is missing. TFRecord data is
    written into it via ``test_utils.create_tfrecord_files`` (one file). The
    returned dataset decodes each record into an ``'image'`` item and a
    ``'label'`` item.

    Args:
      tmpdir: Directory in which the TFRecord file is generated. It is created
        with ``gfile.MakeDirs`` when ``gfile.Exists`` reports it absent.

    Returns:
      A ``Dataset`` that reads the generated sources with
      ``tf.TFRecordReader``, decodes them with a ``TFExampleDecoder``, and
      declares ``num_samples=100``.
    """
    # Only create the output directory when it does not exist yet.
    dir_present = gfile.Exists(tmpdir)
    if not dir_present:
        gfile.MakeDirs(tmpdir)
    sources = test_utils.create_tfrecord_files(tmpdir, num_files=1)

    # Scalar string features; each falls back to a default when the key is
    # absent from a record.
    encoded_spec = tf.FixedLenFeature(
        shape=(), dtype=dtypes.string, default_value='')
    format_spec = tf.FixedLenFeature(
        shape=(), dtype=dtypes.string, default_value='jpeg')
    # Length-1 int64 label. A missing label defaults to a zero vector of the
    # same shape.
    zero_label = array_ops.zeros([1], dtype=dtypes.int64)
    label_spec = tf.FixedLenFeature(
        shape=[1], dtype=dtypes.int64, default_value=zero_label)

    feature_specs = {
        'image/encoded': encoded_spec,
        'image/format': format_spec,
        'image/class/label': label_spec,
    }

    # Map raw features to the items exposed by the dataset.
    handlers_module = tfslim.tfexample_decoder
    image_handler = handlers_module.Image()
    label_handler = handlers_module.Tensor('image/class/label')
    item_handlers = {
        'image': image_handler,
        'label': label_handler,
    }

    example_decoder = TFExampleDecoder(feature_specs, item_handlers)

    # NOTE(review): num_samples is hard-coded to 100; confirm it matches the
    # number of records create_tfrecord_files writes for num_files=1.
    return Dataset(
        data_sources=sources,
        reader=tf.TFRecordReader,
        decoder=example_decoder,
        num_samples=100,
    )
评论列表
文章目录