def input_pipeline(file_pattern, mode, capacity=64):
  """Build an input pipeline that reads tf.Examples and returns int32 tensors.

  Uses slim's parallel_reader to read serialized examples from the files
  matching `file_pattern`, decodes the "inputs" and "targets" features, and
  casts them to int32.

  Args:
    file_pattern: Glob/pattern for the input TFRecord files.
    mode: If equal to "train", read with shuffling, unlimited epochs, and up
      to 4 parallel readers; otherwise read each file once, in order, with a
      single reader (eval/inference).
    capacity: `min_after_dequeue` for the example queue; the queue capacity
      is twice this value.

  Returns:
    A dict mapping "inputs" and "targets" to decoded int32 tensors.
  """
  keys_to_features = {
      "inputs": tf.VarLenFeature(tf.int64),
      "targets": tf.VarLenFeature(tf.int64),
  }
  items_to_handlers = {
      "inputs": tfexample_decoder.Tensor("inputs"),
      "targets": tfexample_decoder.Tensor("targets"),
  }
  # Now the non-trivial case construction.
  with tf.name_scope("examples_queue"):
    training = (mode == "train")
    # Read serialized examples using slim parallel_reader.
    num_epochs = None if training else 1
    data_files = parallel_reader.get_data_files(file_pattern)
    num_readers = min(4 if training else 1, len(data_files))
    _, examples = parallel_reader.parallel_read(
        [file_pattern],
        tf.TFRecordReader,
        num_epochs=num_epochs,
        shuffle=training,
        capacity=2 * capacity,
        min_after_dequeue=capacity,
        num_readers=num_readers)
    decoder = tfexample_decoder.TFExampleDecoder(keys_to_features,
                                                 items_to_handlers)
    # Decode with an explicit item-name list, then zip against that SAME
    # list: the original zipped `keys_to_features` against tensors decoded
    # in `items_to_handlers` order, which pairs fields with the wrong
    # tensors if the two dicts' key orders ever diverge.
    item_names = list(items_to_handlers)
    decoded = decoder.decode(examples, items=item_names)
    examples = dict(zip(item_names, decoded))
    # We do not want int64s as they are not supported on GPUs.
    return {k: tf.to_int32(v) for (k, v) in six.iteritems(examples)}
# (removed: non-code web-scrape residue accidentally appended to the file)