def undo_make_batch_of_event_sample_matrices(
        self, x, sample_shape,
        name="undo_make_batch_of_event_sample_matrices"):
    """Reshape/transpose a `Distribution` `Tensor` from `B_+E_+S_` to `S+B+E`.

    Here:
      - `B_ = B if B else [1]`,
      - `E_ = E if E else [1]`,
      - `S_ = [tf.reduce_prod(S)]`.

    This is the inverse of `make_batch_of_event_sample_matrices`.
    - The flattened sample axis is rotated from the right end to the left end.
    - Any `[1]` placeholder batch/event axes are discarded. In the static case
      they are squeezed away; in the dynamic case the slicing skips them.
    - The result is reshaped so the single sample axis expands back into
      `sample_shape`.

    Args:
      x: `Tensor` of shape `B_+E_+S_`.
      sample_shape: `Tensor` (1D, `int32`).
      name: `String`. The name to give this op.

    Returns:
      x: `Tensor`. Input transposed/reshaped to `S+B+E`.
    """
    with self._name_scope(name, values=[x, sample_shape]):
        tensor = ops.convert_to_tensor(x, name="x")
        sample_shape = ops.convert_to_tensor(sample_shape, name="sample_shape")
        # B_+E_+S_  ->  S_+B_+E_ : the sample axis becomes axis 0.
        tensor = distribution_util.rotate_transpose(tensor, shift=1)

        if not self._is_all_constant_helper(self.batch_ndims, self.event_ndims):
            # Dynamic ndims: carve batch/event shapes out of the full shape.
            if tensor.get_shape().is_fully_defined():
                dims = tensor.get_shape().as_list()
            else:
                dims = array_ops.shape(tensor)
            batch_shape = array_ops.slice(dims, (1,), (self.batch_ndims,))
            # Axis 0 is the single flattened sample axis. The event axes
            # therefore begin right after the batch axes, or at axis 2 when
            # the batch part is only the `[1]` placeholder.
            event_start = math_ops.select(
                self._batch_ndims_is_0, 2, 1 + self.batch_ndims)
            event_shape = array_ops.slice(
                dims, (event_start,), (self.event_ndims,))
        else:
            # Static ndims: drop the `[1]` placeholders, then read the shape.
            squeeze_axes = []
            if self._batch_ndims_is_0:
                # The B_ placeholder sits immediately before the event axes.
                # E_ always has at least one axis, so it is never the last one.
                squeeze_axes.append(min(-2, -1 - self._event_ndims_static))
            if self._event_ndims_is_0:
                # The E_ placeholder is the trailing axis.
                squeeze_axes.append(-1)
            if squeeze_axes:
                tensor = array_ops.squeeze(
                    tensor, squeeze_dims=tuple(squeeze_axes))
            _, batch_shape, event_shape = self.get_shape(tensor)

        target_shape = array_ops.concat(
            0, (sample_shape, batch_shape, event_shape))
        return array_ops.reshape(tensor, shape=target_shape)
# (Scraped web-page residue: "Comment list" / "Table of contents" navigation labels.)