def parse_mnist_tfrec(tfrecord, features_shape, num_classes=10):
    """Decode one serialized TFRecord example into a (features, targets) pair.

    The record is expected to hold two raw byte-string fields, ``'features'``
    and ``'targets'``; both are decoded as ``uint8``.

    Args:
        tfrecord: A scalar string tensor holding one serialized example.
        features_shape: Shape the decoded feature bytes are reshaped to. Its
            element count must match the number of bytes stored in the
            ``'features'`` field.
        num_classes: Depth of the one-hot target vector. Defaults to 10,
            which was the previously hard-coded value.

    Returns:
        A tuple ``(features, targets)`` where ``features`` is a ``float32``
        tensor of shape ``features_shape`` and ``targets`` is a ``float32``
        one-hot vector of length ``num_classes``.
    """
    parsed = tf.parse_single_example(
        tfrecord,
        features={
            'features': tf.FixedLenFeature([], tf.string),
            'targets': tf.FixedLenFeature([], tf.string),
        },
    )

    # Raw bytes -> uint8 values -> requested shape -> float32.
    features = tf.decode_raw(parsed['features'], tf.uint8)
    features = tf.reshape(features, features_shape)
    features = tf.cast(features, tf.float32)

    targets = tf.decode_raw(parsed['targets'], tf.uint8)
    # Reshaping to a scalar only succeeds when exactly one byte was decoded,
    # i.e. the label is stored as a single uint8 class index.
    targets = tf.reshape(targets, [])
    # Requesting float32 directly yields 1.0 / 0.0 entries, so no separate
    # cast of an integer one-hot tensor is needed.
    targets = tf.one_hot(indices=targets, depth=num_classes, dtype=tf.float32)
    return features, targets
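# 评论列表 (comments list -- web page residue, not code)
# 文章目录 (table of contents -- web page residue, not code)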