expert_utils.py 文件源码

python
阅读 76 收藏 0 点赞 0 评论 0

项目:tensor2tensor 作者: tensorflow 项目源码 文件源码
def remove(self, x):
    """Remove padding from the given tensor.

    Args:
      x (tf.Tensor): of shape [dim_origin,...]

    Returns:
      a tensor of shape [dim_compressed,...] with dim_compressed <= dim_origin
    """
    with tf.name_scope("pad_reduce/remove"):
      x_shape = x.get_shape().as_list()
      x = tf.gather_nd(
          x,
          indices=self.nonpad_ids,
      )
      if not context.in_eager_mode():
        # This is a hack but for some reason, gather_nd return a tensor of
        # undefined shape, so the shape is set up manually
        x.set_shape([None] + x_shape[1:])
    return x
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号