textdataflow.py source code

python
Project: tefla  Author: openAGI
    # Assumes module-level `import numpy as np` and `import tensorflow as tf`.
    def bucket_by_sequence_length(self, dataset, example_length_fn, bucket_boundaries,
                                  bucket_batch_sizes, window_size):
        """Bucket entries in dataset by length.

        Args:
          dataset: Dataset of dict<feature name, Tensor>.
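          example_length_fn: function mapping an example to its length (int), which
            determines the bucket the example goes into.
          bucket_boundaries: list<int>, bucket boundaries in increasing order; bucket i
            holds lengths in [bucket_boundaries[i-1], bucket_boundaries[i]), with the
            first and last buckets open-ended.
          bucket_batch_sizes: list<int>, batch size per bucket, with
            len(bucket_boundaries) + 1 entries.
          window_size: int, number of same-bucket examples gathered before batching;
            must be divisible by every element of bucket_batch_sizes.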

        Returns:
          Dataset of padded and batched examples.
        """
        with tf.name_scope("bucket_by_seq_length"):

            def example_to_bucket_id(example):
                """Return int64 id of the length bucket for this example."""
                seq_length = example_length_fn(example)

                boundaries = list(bucket_boundaries)
                # Bucket i covers lengths in [buckets_min[i], buckets_max[i]).
                buckets_min = [np.iinfo(np.int32).min] + boundaries
                buckets_max = boundaries + [np.iinfo(np.int32).max]
                conditions_c = tf.logical_and(
                    tf.less_equal(buckets_min, seq_length),
                    tf.less(seq_length, buckets_max))
                # tf.where returns the int64 indices of matching buckets; take the first.
                bucket_id = tf.reduce_min(tf.where(conditions_c))

                return bucket_id

            def batching_fn(bucket_id, grouped_dataset):
                batch_sizes = tf.constant(bucket_batch_sizes, dtype=tf.int64)
                batch_size = batch_sizes[bucket_id]

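                # Pad each dimension of each feature to the longest value in the batch.
                # tf.compat.v1.data.get_output_shapes also works on TF 2.x datasets,
                # which no longer expose the TF1 `output_shapes` attribute.
                output_shapes = tf.compat.v1.data.get_output_shapes(grouped_dataset)
                padded_shapes = {name: [None] * len(shape)
                                 for name, shape in output_shapes.items()}
                return grouped_dataset.padded_batch(batch_size, padded_shapes)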

            dataset = dataset.group_by_window(example_to_bucket_id, batching_fn,
                                              window_size)
            return dataset
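
To make the bucketing rule concrete, here is a framework-free Python sketch (no TensorFlow) that mirrors what the method above does. `bucket_id` reproduces the `example_to_bucket_id` interval test, `padded_batches` imitates `padded_batch`, and `bucket_by_length` imitates `group_by_window` by buffering examples per bucket and emitting a window once it holds `window_size` examples. All three helper names are hypothetical, sequences are plain lists padded with 0, and the order in which leftover windows are flushed at the end of input is chosen for illustration and may differ from tf.data.

```python
from collections import defaultdict
from typing import Dict, Iterable, Iterator, List, Sequence

INT32_MIN, INT32_MAX = -2**31, 2**31 - 1


def bucket_id(length: int, boundaries: Sequence[int]) -> int:
    """Hypothetical mirror of example_to_bucket_id: first i with min_i <= length < max_i."""
    buckets_min = [INT32_MIN] + list(boundaries)
    buckets_max = list(boundaries) + [INT32_MAX]
    matches = [i for i, (lo, hi) in enumerate(zip(buckets_min, buckets_max))
               if lo <= length < hi]
    return min(matches)


def padded_batches(window: List[List[int]], batch_size: int,
                   pad: int = 0) -> Iterator[List[List[int]]]:
    """Hypothetical mirror of padded_batch: split a window, pad each batch to its longest row."""
    for start in range(0, len(window), batch_size):
        batch = window[start:start + batch_size]
        width = max(len(seq) for seq in batch)
        yield [seq + [pad] * (width - len(seq)) for seq in batch]


def bucket_by_length(sequences: Iterable[List[int]], boundaries: Sequence[int],
                     batch_sizes: Sequence[int],
                     window_size: int) -> Iterator[List[List[int]]]:
    """Hypothetical mirror of group_by_window + padded_batch."""
    buffers: Dict[int, List[List[int]]] = defaultdict(list)
    for seq in sequences:
        b = bucket_id(len(seq), boundaries)
        buffers[b].append(seq)
        if len(buffers[b]) == window_size:  # window full: batch it and start a new one
            yield from padded_batches(buffers.pop(b), batch_sizes[b])
    # Flush partially filled windows at end of input (bucket order is an assumption).
    for b in sorted(buffers):
        yield from padded_batches(buffers[b], batch_sizes[b])


seqs = [[1], [2, 2, 2], [3, 3], [4, 4, 4, 4, 4], [5]]
for batch in bucket_by_length(seqs, boundaries=[2, 4], batch_sizes=[2, 2, 1], window_size=2):
    print(batch)
# [[2, 2, 2], [3, 3, 0]]
# [[1], [5]]
# [[4, 4, 4, 4, 4]]
```

With boundaries [2, 4] there are three buckets: length 1 goes to bucket 0, lengths 2 and 3 to bucket 1, and lengths of 4 or more to bucket 2. Because `window_size` is divisible by every batch size, a full window always splits into full batches, so only the end-of-input flush can yield a short batch.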