data_reader.py source code

python

Project: tensor2tensor  Author: tensorflow
import numpy as np
import tensorflow as tf  # TensorFlow 1.x: tf.contrib was removed in 2.x


def bucket_by_sequence_length(dataset,
                              example_length_fn,
                              bucket_boundaries,
                              bucket_batch_sizes,
                              padded_shapes=None):
  """Bucket entries in dataset by length.

  Args:
    dataset: Dataset of dict<feature name, Tensor>.
    example_length_fn: function from example to int, determines the length of
      the example, which will determine the bucket it goes into.
    bucket_boundaries: list<int>, boundaries of the buckets.
    bucket_batch_sizes: list<int>, batch size per bucket.
    padded_shapes: dict<feature name, list<int>>, optional, shapes of the
      features with None where feature should be padded to max in that dim.

  Returns:
    Dataset of padded and batched examples.
  """
  with tf.name_scope("bucket_by_seq_length"):

    def example_to_bucket_id(example):
      """Return int64 id of the length bucket for this example."""
      seq_length = example_length_fn(example)

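      # Bucket i covers the half-open range [buckets_min[i], buckets_max[i]);
      # the int32 min/max sentinels make the first and last buckets open-ended.
      boundaries = list(bucket_boundaries)
      buckets_min = [np.iinfo(np.int32).min] + boundaries
      buckets_max = boundaries + [np.iinfo(np.int32).max]
      # Boolean vector that is True for the bucket whose range holds seq_length.
      conditions_c = tf.logical_and(
          tf.less_equal(buckets_min, seq_length),
          tf.less(seq_length, buckets_max))
      # tf.where returns the int64 indices of the True entries; take the first.
      bucket_id = tf.reduce_min(tf.where(conditions_c))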

      return bucket_id

    def window_size_fn(bucket_id):
      # window size = batch size
      batch_sizes = tf.constant(bucket_batch_sizes, dtype=tf.int64)
      window_size = batch_sizes[bucket_id]
      return window_size

    def batching_fn(bucket_id, grouped_dataset):
      batch_sizes = tf.constant(bucket_batch_sizes, dtype=tf.int64)
      batch_size = batch_sizes[bucket_id]
      # padded_batch is a helper defined elsewhere in data_reader.py (not shown).
      return padded_batch(grouped_dataset, batch_size, padded_shapes)

    dataset = dataset.apply(
        tf.contrib.data.group_by_window(
            key_func=example_to_bucket_id,
            reduce_func=batching_fn,
            window_size=None,
            window_size_func=window_size_fn))
    return dataset
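The bucket lookup in example_to_bucket_id can be checked without TensorFlow. Below is a minimal pure-Python sketch, assuming bucket_boundaries is sorted in ascending order; it swaps the tf.where/tf.reduce_min mask for bisect.bisect_right, and the name bucket_id_for_length is hypothetical. Because n boundaries produce n + 1 buckets, bucket_batch_sizes (which is indexed by bucket id) needs n + 1 entries.

```python
import bisect
from typing import Sequence


def bucket_id_for_length(seq_length: int, bucket_boundaries: Sequence[int]) -> int:
    """Return the index of the bucket whose half-open range holds seq_length.

    Bucket i covers [lower_i, upper_i); the first bucket has no lower bound
    and the last has no upper bound. Assumes ascending, sorted boundaries.
    """
    # The count of boundaries <= seq_length is exactly the bucket index.
    return bisect.bisect_right(bucket_boundaries, seq_length)


boundaries = [8, 16, 32]  # hypothetical boundaries -> 4 buckets
for length in (3, 8, 15, 16, 40):
    print(length, "->", bucket_id_for_length(length, boundaries))
```

A length equal to a boundary lands in the higher bucket, matching the `tf.less_equal(buckets_min, ...)` / `tf.less(..., buckets_max)` pair in the original.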
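Since window_size_fn returns the bucket's batch size, each window that group_by_window hands to batching_fn becomes exactly one padded batch. The sketch below is a plain-Python model of that behavior, not the TensorFlow implementation: it assumes examples are lists of ints rather than dicts of tensors, pads each batch to its longest sequence (mirroring `None` entries in padded_shapes), and flushes partially filled buckets at the end of input in bucket-id order for determinism. The names bucket_and_batch, pad_batch and short_or_long are hypothetical.

```python
from collections import defaultdict
from typing import Callable, Dict, Iterable, Iterator, List, Sequence


def pad_batch(batch: List[List[int]], pad_value: int = 0) -> List[List[int]]:
    """Right-pad every sequence in the batch to the longest one in that batch."""
    max_len = max(len(seq) for seq in batch)
    return [seq + [pad_value] * (max_len - len(seq)) for seq in batch]


def bucket_and_batch(
    examples: Iterable[List[int]],
    bucket_fn: Callable[[List[int]], int],
    bucket_batch_sizes: Sequence[int],
) -> Iterator[List[List[int]]]:
    """Yield padded batches, each drawn from a single bucket.

    A bucket emits a batch as soon as it holds bucket_batch_sizes[bucket_id]
    examples (window size = batch size); leftovers are flushed at the end.
    """
    pending: Dict[int, List[List[int]]] = defaultdict(list)
    for example in examples:
        bucket_id = bucket_fn(example)
        pending[bucket_id].append(example)
        if len(pending[bucket_id]) == bucket_batch_sizes[bucket_id]:
            # Window is full: remove it from pending and emit one batch.
            yield pad_batch(pending.pop(bucket_id))
    # End of input: flush each partially filled bucket.
    for bucket_id in sorted(pending):
        yield pad_batch(pending[bucket_id])


def short_or_long(seq: List[int]) -> int:
    # Hypothetical two-bucket rule: bucket 0 for length < 3, bucket 1 otherwise.
    return 0 if len(seq) < 3 else 1


data = [[1], [2, 2, 2], [3, 3], [4, 4, 4, 4], [5]]
batches = list(bucket_and_batch(data, short_or_long, bucket_batch_sizes=[2, 2]))
for batch in batches:
    print(batch)
```

Grouping by length before batching keeps sequences of similar size together, so each batch needs far less padding than batching the stream in arrival order would.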