tensorflow_backend.py source code

python

Project: keras-rcnn  Author: broadinstitute
import tensorflow


def scatter_add_tensor(ref, indices, updates, name=None):
    """
    Adds sparse updates to a tensor and returns the result.

    This operation returns a new tensor holding ref plus the scattered
    updates; ref itself is not modified. This makes it easier to chain
    operations that need to use the updated value.

    Duplicate indices: if multiple indices reference the same location,
    their contributions add.

    Requires updates.shape = indices.shape[:-1] + ref.shape[indices.shape[-1]:],
    as expected by tensorflow.scatter_nd.

    :param ref: A Tensor. Must be one of the following types: float32,
    float64, int64, int32, uint8, uint16, int16, int8, complex64, complex128,
    qint8, quint8, qint32, half.

    :param indices: A Tensor. Must be one of the following types: int32,
    int64. A tensor of indices into ref; its last dimension indexes into the
    leading dimensions of ref.

    :param updates: A Tensor. Must have the same dtype as ref. A tensor of
    values to add to ref.

    :param name: A name for the operation (optional).

    :return: Same as ref. Returned as a convenience for operations that want
    to use the updated values after the update is done.
    """
    with tensorflow.name_scope(name, 'scatter_add_tensor', [ref, indices, updates]) as scope:
        ref = tensorflow.convert_to_tensor(ref, name='ref')

        indices = tensorflow.convert_to_tensor(indices, name='indices')

        updates = tensorflow.convert_to_tensor(updates, name='updates')

        ref_shape = tensorflow.shape(ref, out_type=indices.dtype, name='ref_shape')

        scattered_updates = tensorflow.scatter_nd(indices, updates, ref_shape, name='scattered_updates')

        with tensorflow.control_dependencies([tensorflow.assert_equal(ref_shape, tensorflow.shape(scattered_updates, out_type=indices.dtype))]):
            output = tensorflow.add(ref, scattered_updates, name=scope)

        return output
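
To make the duplicate-index behaviour described in the docstring concrete, here is a minimal pure-Python sketch of the same semantics for the simplest case, where indices has shape [N, 1] and ref has rank 1 or 2. The helper name `scatter_add_reference` is hypothetical and is not part of keras-rcnn or TensorFlow. It only mirrors what `scatter_add_tensor` is expected to compute, using nested lists in place of tensors.

```python
import copy


def scatter_add_reference(ref, indices, updates):
    """Reference semantics for indices of shape [N, 1] (hypothetical helper).

    ref is a list (rank 1) or a list of lists (rank 2); indices is a list
    of one-element lists; updates holds one scalar or one row per index.
    """
    # Work on a copy so the input is left untouched, as in the tensor version.
    output = copy.deepcopy(ref)
    for (row,), update in zip(indices, updates):
        if isinstance(output[row], list):
            # Rank-2 ref: add the update row element-wise.
            output[row] = [a + b for a, b in zip(output[row], update)]
        else:
            # Rank-1 ref: add a scalar.
            output[row] += update
    return output


ref = [1.0, 2.0, 3.0, 4.0]
indices = [[0], [2], [0]]  # index 0 appears twice, so its updates accumulate
updates = [10.0, 20.0, 30.0]
result = scatter_add_reference(ref, indices, updates)
print(result)  # [41.0, 2.0, 23.0, 4.0]
```

Position 0 receives both 10.0 and 30.0, which matches the "their contributions add" rule stated in the docstring, and `ref` keeps its original values.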