def create_input_fn(pipeline,
                    batch_size,
                    bucket_boundaries=None,
                    allow_smaller_final_batch=False,
                    scope=None):
  """Builds an input function compatible with tf.learn estimators.

  The returned closure constructs all of its ops lazily, which ensures
  the data provider and featurizer end up in the same graph as the
  estimator that invokes it.

  Args:
    pipeline: An instance of `seq2seq.data.InputPipeline`.
    batch_size: Number of examples per batch. A queue holding a
      reasonable number of batches in memory is created.
    bucket_boundaries: Increasing, non-negative int list. If None,
      no bucketing is performed.
    allow_smaller_final_batch: If True, the final batch may contain
      fewer than `batch_size` examples.
    scope: Optional variable-scope name for the created ops;
      defaults to "input_fn".

  Returns:
    An input function that returns `(feature_batch, labels_batch)`
    tuples when called.
  """

  def input_fn():
    """Builds the feature and label batch tensors."""
    with tf.variable_scope(scope or "input_fn"):
      provider = pipeline.make_data_provider()
      tensors = pipeline.read_from_data_provider(provider)

      queue_capacity = 5000 + 16 * batch_size
      if bucket_boundaries:
        # Group examples of similar source length into the same bucket
        # to reduce padding within each batch.
        _, batched = tf.contrib.training.bucket_by_sequence_length(
            input_length=tensors["source_len"],
            bucket_boundaries=bucket_boundaries,
            tensors=tensors,
            batch_size=batch_size,
            keep_input=tensors["source_len"] >= 1,
            dynamic_pad=True,
            capacity=queue_capacity,
            allow_smaller_final_batch=allow_smaller_final_batch,
            name="bucket_queue")
      else:
        batched = tf.train.batch(
            tensors=tensors,
            enqueue_many=False,
            batch_size=batch_size,
            dynamic_pad=True,
            capacity=queue_capacity,
            allow_smaller_final_batch=allow_smaller_final_batch,
            name="batch_queue")

      # Split the combined batch dict back into features vs. labels.
      features_batch = {key: batched[key] for key in pipeline.feature_keys}
      if set(batched.keys()) & set(pipeline.label_keys):
        labels_batch = {key: batched[key] for key in pipeline.label_keys}
      else:
        labels_batch = None
      return features_batch, labels_batch

  return input_fn