# Assumes `import numpy as np` and `import tensorflow as tf` at module level
# (TF 1.x-era tf.data, where Dataset exposes group_by_window and output_shapes).
def bucket_by_sequence_length(self, dataset, example_length_fn, bucket_boundaries,
                              bucket_batch_sizes, window_size):
  """Bucket entries in dataset by length.

  Args:
    dataset: Dataset of dict<feature name, Tensor>.
    example_length_fn: function from example to int, determines the length of
      the example, which will determine the bucket it goes into.
    bucket_boundaries: list<int>, boundaries of the buckets.
    bucket_batch_sizes: list<int>, batch size per bucket; must have one more
      entry than bucket_boundaries.
    window_size: an integer divisible by all elements of bucket_batch_sizes.

  Returns:
    Dataset of padded and batched examples.
  """
  with tf.name_scope("bucket_by_seq_length"):

    def example_to_bucket_id(example):
      """Return int64 id of the length bucket for this example."""
      seq_length = example_length_fn(example)

      # Bucket i covers lengths in [buckets_min[i], buckets_max[i]); the
      # int32 min/max sentinels close the two open-ended buckets.
      boundaries = list(bucket_boundaries)
      buckets_min = [np.iinfo(np.int32).min] + boundaries
      buckets_max = boundaries + [np.iinfo(np.int32).max]
      # Exactly one interval contains seq_length; its index is the bucket id.
      conditions_c = tf.logical_and(
          tf.less_equal(buckets_min, seq_length),
          tf.less(seq_length, buckets_max))
      bucket_id = tf.reduce_min(tf.where(conditions_c))
      return bucket_id
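
    # Illustrative check (hypothetical values, not from the original): with
    # bucket_boundaries = [8, 16] and an example of length 10, the intervals
    # are [min, 8), [8, 16), [16, max); only 8 <= 10 < 16 holds, so
    # conditions_c is [False, True, False] and bucket_id is 1.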

    def batching_fn(bucket_id, grouped_dataset):
      """Pad and batch a window of examples that all fall in one bucket."""
      batch_sizes = tf.constant(bucket_batch_sizes, dtype=tf.int64)
      batch_size = batch_sizes[bucket_id]
      # Pad each dimension of each feature so that they match.
      padded_shapes = dict(
          [(name, [None] * len(shape))
           for name, shape in grouped_dataset.output_shapes.items()])
      return grouped_dataset.padded_batch(batch_size, padded_shapes)
    dataset = dataset.group_by_window(example_to_bucket_id, batching_fn,
                                      window_size)
    return dataset
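
A minimal usage sketch, assuming this method sits on some data-reader object (here called `reader`) and that the dataset yields dicts with an "inputs" feature; the feature name and all values below are illustrative assumptions, not from the original:

import tensorflow as tf

def example_length(example):
  # Sequence length = size of the first dimension of the (assumed) "inputs" feature.
  return tf.shape(example["inputs"])[0]

boundaries = [8, 16, 32]       # 4 buckets: [min, 8), [8, 16), [16, 32), [32, max)
batch_sizes = [64, 32, 16, 8]  # shorter sequences get larger batches
window = 64                    # divisible by every entry in batch_sizes

batched = reader.bucket_by_sequence_length(
    dataset, example_length, boundaries, batch_sizes, window)

Pairing larger batch sizes with shorter buckets keeps the number of (padded) tokens per batch roughly constant, which is the usual motivation for length bucketing.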