fista_tf.py source code

python

Project: AdaptiveOptim    Author: tomMoral
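import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph-mode API


def _get_step(self, inputs):
    """Build the graph for one FISTA iteration."""
    # Z: previous iterate z_k, Y: extrapolated point y_k, X: input signals,
    # theta: momentum scalar t_k, lmbd: l1 regularization weight.
    Z, Y, X, theta, lmbd = inputs
    K, p = self.D.shape
    L = self.L  # step size is 1/L, L = Lipschitz constant of the smooth term
    with tf.name_scope("ISTA_iteration"):
        # hk = Y S + X We is a gradient step from Y on 0.5*||X - Z D||^2,
        # provided S0 = D D^T. Force float32 so both matmuls share a dtype.
        self.S = tf.constant(np.eye(K, dtype=np.float32) - self.S0/L,
                             shape=[K, K], dtype=tf.float32, name='S')
        self.We = tf.constant(self.D.T/L, shape=[p, K],
                              dtype=tf.float32, name='We')
        hk = tf.matmul(Y, self.S) + tf.matmul(X, self.We)
        # Proximal step for the l1 penalty; soft_thresholding is a helper
        # defined elsewhere in the project.
        self.step_FISTA = Zk = soft_thresholding(hk, lmbd/L)
        # Momentum update: t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2
        self.theta_k = tk = (1 + tf.sqrt(1 + 4*theta*theta))/2
        dZ = tf.subtract(Zk, Z)
        # Extrapolation: y_{k+1} = z_{k+1} + (t_k - 1)/t_{k+1} * (z_{k+1} - z_k)
        self.Yk = Zk + (theta-1)/tk*dZ
        # Mean (over samples) squared norm of the update, a convergence measure.
        self.dz = tf.reduce_mean(tf.reduce_sum(dZ*dZ, axis=1))

        step = tf.tuple([Zk, tk, self.Yk])
    return step, self.dz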
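The same iteration can be checked without TensorFlow 1.x. The NumPy sketch below is an illustration, not the project's code: it assumes the objective is 0.5 * ||X - Z D||^2 + lmbd * ||Z||_1, so that S0 = D D^T and L is the largest eigenvalue of S0 (the snippet does not show how S0 and L are built). The names `fista` and `cost` are hypothetical, and `soft_thresholding` here is a stand-in for the project's helper of the same name.

```python
import numpy as np


def soft_thresholding(x, threshold):
    """Elementwise shrinkage: the proximal operator of threshold * ||x||_1."""
    return np.sign(x) * np.maximum(np.abs(x) - threshold, 0.0)


def fista(X, D, lmbd, n_iter=100):
    """Hypothetical FISTA loop for min_Z 0.5*||X - Z D||^2 + lmbd*||Z||_1.

    X has shape (n, p) and D has shape (K, p), matching the row-vector
    convention hk = Y S + X We used in the TensorFlow step.
    """
    K = D.shape[0]
    S0 = D @ D.T                    # assumption: S0 is the Gram matrix D D^T
    L = np.linalg.norm(S0, ord=2)   # largest eigenvalue of the PSD matrix S0
    S = np.eye(K) - S0 / L
    We = D.T / L
    Z = np.zeros((X.shape[0], K))   # z_0 = 0
    Y = Z.copy()                    # y_1 = z_0
    theta = 1.0                     # t_1 = 1
    for _ in range(n_iter):
        Zk = soft_thresholding(Y @ S + X @ We, lmbd / L)   # proximal gradient step
        tk = (1 + np.sqrt(1 + 4 * theta * theta)) / 2     # momentum update
        Y = Zk + (theta - 1) / tk * (Zk - Z)              # extrapolation
        Z, theta = Zk, tk
    return Z


def cost(X, D, Z, lmbd):
    """Lasso objective evaluated at Z."""
    return 0.5 * np.sum((X - Z @ D) ** 2) + lmbd * np.abs(Z).sum()
```

For example, with `D` of shape (5, 8) and `X` of shape (3, 8), `fista(X, D, lmbd=0.5, n_iter=500)` returns a (3, 5) sparse code. FISTA does not decrease the cost monotonically, so tracking the size of the update, as `dz` does above, is a convenient stopping signal.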