def parse_mnist_tfrec(tfrecord, name, features_shape, scalar_targs=False,
                      *, num_classes=10):
    """Decode one serialized example into ``(features, targets)`` float32 tensors.

    The example must contain two byte-string features, ``'features'`` and
    ``'targets'``, whose bytes are decoded as ``uint8``.

    Args:
        tfrecord: Serialized example (string tensor) passed to
            ``tf.parse_single_example``.
        name: Prefix for the parse op; the op is named ``name + '_data'``.
        features_shape: Shape the decoded feature bytes are reshaped to. Its
            element count must equal the number of stored feature bytes.
        scalar_targs: If True, the target bytes must hold exactly one value
            (a class index), which is one-hot encoded to length
            ``num_classes``. If False, the target bytes are only cast to
            float32 -- presumably they are already one-hot encoded; confirm
            against whatever wrote the records.
        num_classes: One-hot depth used when ``scalar_targs`` is True.
            Defaults to 10, the previously hard-coded value.

    Returns:
        ``(features, targets)``, both ``float32`` tensors. Feature values are
        not rescaled, so they keep the 0-255 range of the stored ``uint8``
        bytes.
    """
    parsed = tf.parse_single_example(
        tfrecord,
        features={
            'features': tf.FixedLenFeature([], tf.string),
            'targets': tf.FixedLenFeature([], tf.string),
        },
        name=name + '_data',
    )

    with tf.variable_scope('features'):
        features = tf.decode_raw(parsed['features'], tf.uint8)
        features = tf.reshape(features, features_shape)
        features = tf.cast(features, tf.float32)

    with tf.variable_scope('targets'):
        targets = tf.decode_raw(parsed['targets'], tf.uint8)
        if scalar_targs:
            # The single stored byte is the class index; collapse it to a
            # scalar so one_hot yields shape [num_classes], not [1, num_classes].
            targets = tf.reshape(targets, [])
            targets = tf.one_hot(
                indices=targets,
                depth=num_classes,
                on_value=1.0,
                off_value=0.0,
                dtype=tf.float32,
            )
        # No-op for the one-hot branch (already float32); converts the raw
        # uint8 targets otherwise.
        targets = tf.cast(targets, tf.float32)

    return features, targets
评论列表
文章目录