ops.py source code

Project: LiTeFlow  Author: petrux
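
```python
import tensorflow as tf


def trim(tensor, width):
    """Trim the tensor on the -1 axis.

    Trim a given tensor of shape `[..., in_width]` to a smaller tensor
    of shape `[..., width]`, along the -1 axis. If the `width` argument
    is greater than or equal to the actual width of the tensor, no
    operation is performed.

    Arguments:
      tensor: a 3D tf.Tensor of shape `[..., in_width]`.
      width: an `int` representing the target value of the 3rd
        (last) dimension of the output tensor.

    Returns:
      a 3D tensor of shape `[..., width]` where the third dimension
        is the minimum between the input width and the value of the
        `width` argument.

    Example:
      # t is a tensor like:
      # [[[1, 1, 1],
      #   [2, 2, 2],
      #   [3, 3, 3]],
      #  [[7, 7, 7],
      #   [8, 8, 8],
      #   [9, 9, 9]]]

      q = trim(t, 2)

      # q is a tensor like:
      # [[[1, 1],
      #   [2, 2],
      #   [3, 3]],
      #  [[7, 7],
      #   [8, 8],
      #   [9, 9]]]
    """
    # Trim only when the dynamic width exceeds `width`; otherwise the
    # input is returned unchanged.
    # NOTE: `_trim` is a private helper defined elsewhere in ops.py;
    # it is not shown on this page.
    result = tf.cond(
        tf.less_equal(tf.shape(tensor)[-1], width),
        lambda: tensor,
        lambda: _trim(tensor, width))
    # Restore the static shape: the last dimension is
    # min(in_width, width) when in_width is statically known, and is
    # left unknown (None) otherwise, since it cannot be decided before
    # the graph runs.
    shape = tensor.get_shape().as_list()
    in_width = shape[-1]
    out_width = None if in_width is None else min(in_width, width)
    result.set_shape(shape[:-1] + [out_width])
    return result
```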
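The trimming rule in the docstring can be illustrated without TensorFlow. The block below is a minimal pure-Python sketch, assuming the input is a nested list whose innermost lists play the role of the last axis; `trim_lists` and `trimmed_static_shape` are hypothetical names chosen for illustration and are not part of LiTeFlow.

```python
from typing import List, Optional


def trim_lists(nested, width: int):
    """Trim the innermost (last-axis) lists of `nested` to at most `width`.

    Hypothetical helper mirroring the documented semantics of `trim`:
    rows already of length `width` or shorter are returned unchanged.
    """
    if nested and isinstance(nested[0], list):
        # Not yet at the last axis: recurse into each sub-list.
        return [trim_lists(sub, width) for sub in nested]
    # Last axis reached: slicing never pads, so short rows stay as-is.
    return nested[:width]


def trimmed_static_shape(shape: List[Optional[int]],
                         width: int) -> List[Optional[int]]:
    """Hypothetical helper: output shape with last dim min(in_width, width).

    An unknown input width (None) is assumed to stay unknown.
    """
    in_width = shape[-1]
    out_width = None if in_width is None else min(in_width, width)
    return shape[:-1] + [out_width]


t = [[[1, 1, 1], [2, 2, 2], [3, 3, 3]],
     [[7, 7, 7], [8, 8, 8], [9, 9, 9]]]

print(trim_lists(t, 2))
# [[[1, 1], [2, 2], [3, 3]], [[7, 7], [8, 8], [9, 9]]]
print(trim_lists(t, 5) == t)
# True
print(trimmed_static_shape([2, 3, 3], 2))     # [2, 3, 2]
print(trimmed_static_shape([2, 3, None], 2))  # [2, 3, None]
```

Slicing past the end of a list returns the whole list, which matches the documented no-op when `width` is at least the input width. The second helper restates the docstring's claim that the output width is the minimum of the input width and `width`; treating an unknown width as still unknown is an assumption of this sketch.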