def _expand_sample_shape(self, sample_shape):
    """Coerce `sample_shape` to a rank-1 shape and compute its element product.

    Helper for `sample`. A scalar `sample_shape` is promoted to a one-element
    vector; the product of its entries is returned alongside it.

    Two paths are taken depending on whether `tensor_util.constant_value`
    can recover the value of `sample_shape`:

    * Value known: the product is computed with NumPy, and the tensor is
      only rebuilt when the static rank is 0.
    * Value unknown: the reshape and the product are expressed with
      `array_ops` / `math_ops` calls instead.

    Args:
      sample_shape: Object exposing `get_shape()` (presumably an integer
        `Tensor` -- confirm against callers).

    Returns:
      A `(sample_shape, total)` pair. `total` is the result of
      `np.prod(..., dtype=int32)` when the value is known, otherwise the
      result of `math_ops.reduce_prod`.

    Raises:
      ValueError: If the value is known but the static rank is `None`.
    """
    known_value = tensor_util.constant_value(sample_shape)
    static_shape = sample_shape.get_shape()
    rank = static_shape.ndims

    if known_value is not None:
        # Value is available up front: do the bookkeeping in NumPy.
        if rank is None:
            raise ValueError(
                "Shouldn't be here; ndims cannot be none when we have a "
                "tf.constant shape.")
        if rank == 0:
            # Promote the scalar to a length-1 vector and rebuild the tensor.
            known_value = np.reshape(known_value, [1])
            sample_shape = ops.convert_to_tensor(
                known_value, dtype=dtypes.int32, name="sample_shape")
        int32_np = dtypes.int32.as_numpy_dtype()
        return sample_shape, np.prod(known_value, dtype=int32_np)

    # Value is not available: fall back to a runtime rank whenever the
    # static shape does not pin everything down.
    if rank is None or not static_shape.is_fully_defined():
        rank = array_ops.rank(sample_shape)

    is_scalar = math_ops.equal(rank, 0)
    unit_vector_shape = np.array((1,), dtype=dtypes.int32.as_numpy_dtype())
    # Choose shape [1] for a scalar input, otherwise keep the input's own shape.
    vector_shape = distribution_util.pick_vector(
        is_scalar, unit_vector_shape, array_ops.shape(sample_shape))
    sample_shape = array_ops.reshape(sample_shape, vector_shape)
    # reduce_prod([]) == 1
    return sample_shape, math_ops.reduce_prod(sample_shape)
评论列表
文章目录