learn_comb.py 文件源码

python
阅读 30 收藏 0 点赞 0 评论 0

项目:DMNN 作者: magnux 项目源码 文件源码
def learn_comb_orth_rmsprop(poses, dm_shape, reuse=None, _float_type=tf.float32):
    """Project `poses` through a learnable combination matrix kept orthogonal.

    The matrix is created with ``trainable=False``; it is updated only through
    the returned ``update_comb_mat`` callable. That callable applies an
    RMSProp-style scaled gradient through a Cayley transform, which preserves
    orthonormal columns.

    Args:
        poses: rank-4 tensor whose axis 2 has size ``dm_shape[1]``. It is
            contracted against the matrix's second axis. NOTE(review): the
            axis meanings (e.g. batch, time, joints, coords) are not visible
            here; confirm against the caller.
        dm_shape: ``(rows, cols)`` of the combination matrix.
        reuse: forwarded to ``tf.variable_scope("learn_comb", ...)``.
        _float_type: dtype of the matrix and its second-moment accumulator.

    Returns:
        A tuple ``(poses, comb_matrix_image, update_comb_mat)``:
          * ``poses``: the input with axis 2 replaced by a ``dm_shape[0]``-sized
            combination, i.e. ``tensordot`` then transpose of the last two axes.
          * ``comb_matrix_image``: ``uint8`` tensor of shape
            ``[1, dm_shape[0], dm_shape[1], 1]``. It is the matrix min-max
            rescaled to [0, 255] for image summaries.
          * ``update_comb_mat(grad, lr)``: builds and returns the op that
            updates the second-moment accumulator and assigns the new matrix.
            ``grad`` must have shape ``dm_shape``.
    """
    with tf.variable_scope("learn_comb", reuse=reuse):
        # identity_initializer is defined elsewhere; presumably it yields an
        # (orthogonal) identity-like start value -- TODO confirm.
        comb_matrix = tf.get_variable(
            "matrix", [dm_shape[0], dm_shape[1]],
            initializer=identity_initializer(0),
            dtype=_float_type, trainable=False
        )
        # Running mean of squared gradients (RMSProp second moment; the
        # variable name says "momentum" but no first moment is kept).
        comb_matrix_m = tf.get_variable(
            "matrix_momentum", [dm_shape[0], dm_shape[1]],
            initializer=tf.zeros_initializer(),
            dtype=_float_type, trainable=False
        )
        tf.add_to_collection(COMB_MATRIX_COLLECTION, comb_matrix)
        # out[..., c, i] = sum_j poses[..., j, c] * W[i, j]; the transpose then
        # puts the combined axis back at position 2.
        poses = tf.tensordot(poses, comb_matrix, [[2], [1]])
        poses = tf.transpose(poses, [0, 1, 3, 2])

        def update_comb_mat(grad, lr):
            """Return the op applying one orthogonality-preserving RMSProp step."""
            I = tf.constant(np.eye(dm_shape[0]), dtype=_float_type)

            # Use the tensor returned by assign instead of re-reading the
            # variable: a ref-variable read under control_dependencies is not
            # guaranteed to observe the update, whereas this is a data dependency.
            new_m = tf.assign(comb_matrix_m,
                              comb_matrix_m * 0.99 + (1 - 0.99) * tf.square(grad))
            scaled_grad = lr * grad / tf.sqrt(new_m + 1.e-5)

            # Cayley update for left multiplication: A = G W^T - W G^T is
            # skew-symmetric and dm_shape[0] x dm_shape[0] (matching I), and
            # -A W is a descent direction. The previous G^T W - W^T G only
            # type-checked for square W and pointed uphill at W = I.
            A = tf.matmul(scaled_grad, comb_matrix, transpose_b=True) - \
                tf.matmul(comb_matrix, scaled_grad, transpose_b=True)
            t1 = I + 0.5 * A
            t2 = I - 0.5 * A
            # (I + A/2)^{-1} (I - A/2) W, solved rather than explicitly
            # inverted; t1 is invertible because A is skew-symmetric.
            Y = tf.matrix_solve(t1, tf.matmul(t2, comb_matrix))
            return tf.assign(comb_matrix, Y)

        # Visualization: min-max rescale to [0, 255]. The range is clamped so a
        # constant matrix gives zeros instead of NaN (0/0) before the uint8 cast.
        cb_min = tf.reduce_min(comb_matrix)
        cb_max = tf.reduce_max(comb_matrix)
        cb_range = tf.maximum(cb_max - cb_min, 1e-6)
        comb_matrix_image = (comb_matrix - cb_min) / cb_range * 255.0
        comb_matrix_image = tf.cast(comb_matrix_image, tf.uint8)
        comb_matrix_image = tf.reshape(comb_matrix_image, [1, dm_shape[0], dm_shape[1], 1])

        return poses, comb_matrix_image, update_comb_mat
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号