import numpy as np
import tensorflow as tf


def bucket_by_sequence_length(dataset,
                              example_length_fn,
                              bucket_boundaries,
                              bucket_batch_sizes,
                              padded_shapes=None):
"""Bucket entries in dataset by length.
Args:
dataset: Dataset of dict<feature name, Tensor>.
example_length_fn: function from example to int, determines the length of
the example, which will determine the bucket it goes into.
bucket_boundaries: list<int>, boundaries of the buckets.
bucket_batch_sizes: list<int>, batch size per bucket.
padded_shapes: dict<feature name, list<int>>, optional, shapes of the
features with None where feature should be padded to max in that dim.
Returns:
Dataset of padded and batched examples.
"""
  with tf.name_scope("bucket_by_seq_length"):

    def example_to_bucket_id(example):
      """Return int64 id of the length bucket for this example."""
      seq_length = example_length_fn(example)

      boundaries = list(bucket_boundaries)
      buckets_min = [np.iinfo(np.int32).min] + boundaries
      buckets_max = boundaries + [np.iinfo(np.int32).max]
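      # Worked example: with bucket_boundaries = [8, 16], buckets_min is
      # [INT32_MIN, 8, 16] and buckets_max is [8, 16, INT32_MAX], giving the
      # half-open buckets [min, 8), [8, 16), [16, max). A sequence of length
      # 10 satisfies only 8 <= 10 < 16, so it falls in bucket_id 1.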
      conditions_c = tf.logical_and(
          tf.less_equal(buckets_min, seq_length),
          tf.less(seq_length, buckets_max))
      bucket_id = tf.reduce_min(tf.where(conditions_c))
      return bucket_id
    def window_size_fn(bucket_id):
      # Window size = batch size, so each full window yields exactly one batch.
      batch_sizes = tf.constant(bucket_batch_sizes, dtype=tf.int64)
      window_size = batch_sizes[bucket_id]
      return window_size
    def batching_fn(bucket_id, grouped_dataset):
      batch_sizes = tf.constant(bucket_batch_sizes, dtype=tf.int64)
      batch_size = batch_sizes[bucket_id]
      # padded_batch is a helper (not shown here) that wraps tf.data's
      # Dataset.padded_batch, applying the optional padded_shapes.
      return padded_batch(grouped_dataset, batch_size, padded_shapes)
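    # group_by_window assigns each example a key (its bucket id), gathers up
    # to window_size_fn(key) consecutive examples sharing that key into a
    # window, and reduces each window into a batch with batching_fn.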
    dataset = dataset.apply(
        tf.contrib.data.group_by_window(example_to_bucket_id, batching_fn,
                                        None, window_size_fn))
    return dataset
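
For reference, here is a minimal usage sketch. It is illustrative rather than from the source: it assumes TensorFlow 1.x, that the padded_batch helper referenced above is in scope, and a hypothetical feature named "inputs"; the boundary and batch-size values are arbitrary.

# Toy dataset of variable-length integer sequences under a single "inputs"
# feature. Note bucket_batch_sizes needs one entry per bucket, i.e.
# len(bucket_boundaries) + 1.
toy = tf.data.Dataset.from_generator(
    lambda: ({"inputs": [1] * n} for n in [3, 5, 9, 12, 20, 30]),
    output_types={"inputs": tf.int32},
    output_shapes={"inputs": tf.TensorShape([None])})

batched = bucket_by_sequence_length(
    toy,
    example_length_fn=lambda ex: tf.shape(ex["inputs"])[0],
    bucket_boundaries=[8, 16],     # buckets [min, 8), [8, 16), [16, max)
    bucket_batch_sizes=[4, 2, 1])  # smaller batches for longer sequences

# In TF 1.x graph mode, pull padded batches with a one-shot iterator.
features = batched.make_one_shot_iterator().get_next()

Giving longer buckets smaller batch sizes, as above, keeps the number of padded tokens per batch roughly constant across buckets.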