signals.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:nengo_dl 作者: nengo 项目源码 文件源码
def _scatter_f_var(self, dst, src, mode="update"):
    """Write ``src`` into the base variable that ``dst`` refers to.

    The target is ``self.bases[dst.key]``. If ``dst`` addresses every
    row of that variable, the whole variable is written with a dense
    ``tf.assign``/``tf.assign_add``. Otherwise only the rows in
    ``dst.tf_indices`` are written with ``tf.scatter_update``/
    ``tf.scatter_add``.

    Parameters
    ----------
    dst
        Signal describing the write target. This method reads its
        ``key``, ``indices``, ``as_slice`` and ``tf_indices`` attributes.
        NOTE(review): ``as_slice`` is presumably non-None only when
        ``indices`` are contiguous; confirm against the signal class.
    src
        Tensor holding the values to write.
    mode : str
        ``"inc"`` adds ``src`` to the target; any other value overwrites
        the target with ``src``.

    Returns
    -------
    The TensorFlow op/tensor returned by the assign or scatter call.
    """
    # Writing directly into the base variable lets us use TF's in-place
    # variable update ops. An earlier experiment copied into a temporary
    # variable first; it was not needed and has been removed.
    var = self.bases[dst.key]
    var_shape = var.get_shape()

    # Use an int (or None if unknown) so every comparison below is
    # int-vs-int and an unknown leading dim cannot raise on "- 1".
    n_rows = var_shape[0].value
    indices = dst.indices

    covers_all_rows = (
        dst.as_slice is not None
        and n_rows is not None
        and len(indices) == n_rows
        # Check the length before indexing so an empty index list on a
        # zero-row variable does not raise IndexError.
        and (n_rows == 0 or (indices[0] == 0 and indices[-1] == n_rows - 1))
        and var_shape.is_compatible_with(src.get_shape())
    )

    if covers_all_rows:
        # Full overwrite/increment: dense ops avoid the scatter machinery.
        if mode == "inc":
            return tf.assign_add(var, src, use_locking=False)
        return tf.assign(var, src, use_locking=False)

    if mode == "inc":
        return tf.scatter_add(var, dst.tf_indices, src, use_locking=False)
    return tf.scatter_update(var, dst.tf_indices, src, use_locking=False)
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号