def read_and_decode_batch(filename_queue, batch_size, capacity, min_after_dequeue):
    """Dequeue a shuffled batch of parsed examples from a TFRecord filename queue.

    Builds TF1 graph ops in three steps:

    1. A single serialized record is read from ``filename_queue``.
    2. ``tf.train.shuffle_batch`` groups ``batch_size`` serialized records.
    3. ``tf.parse_example`` decodes the whole batch in one op.

    Note: ``tf.train.shuffle_batch`` registers a queue runner in the graph, so
    the caller must start the queue runners (e.g. with
    ``tf.train.start_queue_runners``) before evaluating the returned tensors.

    Args:
        filename_queue: Filename queue of the TFRecord file(s), as consumed by
            ``tf.TFRecordReader.read``.
        batch_size: Number of records dequeued per batch.
        capacity: Maximum number of elements held in the shuffle queue.
        min_after_dequeue: Minimum number of elements left in the queue after
            a dequeue. Larger values give better shuffling at the cost of
            memory and start-up time.

    Returns:
        A tuple ``(batch_label, batch_ids, batch_values)``:
            batch_label: dense float32 tensor holding one scalar label per
                record in the batch.
            batch_ids: int64 ``SparseTensor`` built from the variable-length
                "ids" feature. Presumably these are LibSVM feature indices;
                verify against the writer.
            batch_values: float32 ``SparseTensor`` built from the
                variable-length "values" feature. Presumably these are the
                LibSVM feature values paired with ``batch_ids``.
    """
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)

    # Shuffle and batch the *serialized* records first, so that parsing runs
    # once per batch (tf.parse_example) instead of once per record.
    batch_serialized_example = tf.train.shuffle_batch(
        [serialized_example],
        batch_size=batch_size,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue,
    )

    # The feature spec must stay consistent with the LibSVM -> TFRecord
    # conversion that wrote these records (same keys and same dtypes).
    features = tf.parse_example(
        batch_serialized_example,
        features={
            "label": tf.FixedLenFeature([], tf.float32),
            "ids": tf.VarLenFeature(tf.int64),
            "values": tf.VarLenFeature(tf.float32),
        },
    )
    batch_label = features["label"]
    batch_ids = features["ids"]
    batch_values = features["values"]
    return batch_label, batch_ids, batch_values
评论列表
文章目录