shape.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:lsdc 作者: febert 项目源码 文件源码
def undo_make_batch_of_event_sample_matrices(
      self, x, sample_shape, name="undo_make_batch_of_event_sample_matrices"):
    """Inverts `make_batch_of_event_sample_matrices`: maps B_+E_+S_ to S+B+E.

    The input uses "padded" shapes:
      - `B_ = B if B else [1]`,
      - `E_ = E if E else [1]`,
      - `S_ = [tf.reduce_prod(S)]`.

    The sample dimensions are collapsed into one trailing axis. An empty
    batch or event shape is stood in for by a singleton axis. This method
    moves the sample axis to the front and drops those singleton
    placeholders. It then expands the sample axis back out to
    `sample_shape`.

    Args:
      x: `Tensor` of shape `B_+E_+S_`.
      sample_shape: `Tensor` (1D, `int32`). This is the sample shape `S`
        that the collapsed sample axis is reshaped back into.
      name: `String`. The name to give this op.

    Returns:
      x: `Tensor`. Input transposed/reshaped to `S+B+E`.
    """
    with self._name_scope(name, values=[x, sample_shape]):
      tensor = ops.convert_to_tensor(x, name="x")
      sample_shape_t = ops.convert_to_tensor(sample_shape, name="sample_shape")
      # Rotate the trailing collapsed sample axis to the front: S_+B_+E_.
      tensor = distribution_util.rotate_transpose(tensor, shift=1)

      if not self._is_all_constant_helper(self.batch_ndims, self.event_ndims):
        # Dynamic ndims: read the (possibly runtime) shape and slice the batch
        # and event parts out of the S_+B_+E_ layout.
        static_shape = tensor.get_shape()
        if static_shape.is_fully_defined():
          dims = static_shape.as_list()
        else:
          dims = array_ops.shape(tensor)
        # Batch dims start right after the single leading sample axis.
        batch_shape = array_ops.slice(dims, (1,), (self.batch_ndims,))
        # With batch_ndims == 0 the batch slot is still occupied by the
        # placeholder axis of B_, so event dims begin at index 2. Otherwise
        # they begin after the sample axis and the batch dims.
        event_start = math_ops.select(
            self._batch_ndims_is_0, 2, 1 + self.batch_ndims)
        event_shape = array_ops.slice(
            dims, (event_start,), (self.event_ndims,))
      else:
        # Static ndims: remove placeholder singleton axes in Python, then
        # let get_shape split the result into its parts.
        squeeze_axes = ()
        if self._batch_ndims_is_0:
          # The batch placeholder sits immediately before the event dims. If
          # the event part is itself a placeholder, that is index -2.
          squeeze_axes += (min(-2, -1 - self._event_ndims_static),)
        if self._event_ndims_is_0:
          # The event placeholder is always the last axis.
          squeeze_axes += (-1,)
        if squeeze_axes:
          tensor = array_ops.squeeze(tensor, squeeze_dims=squeeze_axes)
        _, batch_shape, event_shape = self.get_shape(tensor)

      full_shape = array_ops.concat(
          0, (sample_shape_t, batch_shape, event_shape))
      return array_ops.reshape(tensor, shape=full_shape)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号