distribution.py 文件源码

python
阅读 25 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def _expand_sample_shape(self, sample_shape):
    """Helper to `sample` which ensures sample_shape is at least 1-D.

    A scalar `sample_shape` is promoted to a length-1 vector; a non-scalar
    `sample_shape` keeps its shape.

    Args:
      sample_shape: Tensor holding the requested sample shape, either a scalar
        or a vector. (Presumably integer-valued -- TODO confirm against the
        callers of `sample`.)

    Returns:
      A `(sample_shape, total)` tuple:
        sample_shape: If its value is not statically known, the input reshaped
          to `[1]` when it is a scalar, or to its own runtime shape otherwise.
          If its value is statically known and it is a scalar, a new int32
          Tensor of shape `[1]` named "sample_shape". If its value is
          statically known and it is non-scalar, the input unchanged.
        total: Product of the entries of `sample_shape`. This is a scalar
          Tensor from `math_ops.reduce_prod` when the value is not statically
          known, and a NumPy int32 scalar otherwise.

    Raises:
      ValueError: If `sample_shape` has a constant value but its static rank
        is unknown.
    """
    # `as_numpy_dtype` is a property that already yields the NumPy scalar type
    # (np.int32); calling it would construct a zero-valued np.int32 instance
    # instead of passing the type itself.
    int32_np_dtype = dtypes.int32.as_numpy_dtype
    sample_shape_static_val = tensor_util.constant_value(sample_shape)
    ndims = sample_shape.get_shape().ndims
    if sample_shape_static_val is None:
      # The value is only known at run time, so build the reshape dynamically.
      if ndims is None or not sample_shape.get_shape().is_fully_defined():
        # The static shape is incomplete, so fall back to the runtime rank.
        ndims = array_ops.rank(sample_shape)
      # Target shape: [1] for a scalar, otherwise the input's own runtime
      # shape (assumes pick_vector returns its first vector when the predicate
      # is true -- see distribution_util).
      expanded_shape = distribution_util.pick_vector(
          math_ops.equal(ndims, 0),
          np.array((1,), dtype=int32_np_dtype),
          array_ops.shape(sample_shape))
      sample_shape = array_ops.reshape(sample_shape, expanded_shape)
      total = math_ops.reduce_prod(sample_shape)  # reduce_prod([]) == 1
    else:
      if ndims is None:
        raise ValueError(
            "Shouldn't be here; ndims cannot be none when we have a "
            "tf.constant shape.")
      if ndims == 0:
        # Static scalar: promote it to a length-1 int32 vector.
        sample_shape_static_val = np.reshape(sample_shape_static_val, [1])
        sample_shape = ops.convert_to_tensor(
            sample_shape_static_val,
            dtype=dtypes.int32,
            name="sample_shape")
      total = np.prod(sample_shape_static_val, dtype=int32_np_dtype)
    return sample_shape, total
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号