learn_comb.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:DMNN 作者: magnux 项目源码 文件源码
def learn_comb_orth(poses, dm_shape, reuse=None, _float_type=tf.float32):
    with tf.variable_scope("learn_comb", reuse=reuse):
        comb_matrix = tf.get_variable(
            "matrix", [dm_shape[0], dm_shape[1]],
            initializer=identity_initializer(0),
            dtype=_float_type, trainable=False
        )
        tf.add_to_collection(COMB_MATRIX_COLLECTION, comb_matrix)
        poses = tf.tensordot(poses, comb_matrix, [[2], [1]])
        poses = tf.transpose(poses, [0, 1, 3, 2])

        # Special update code
        def update_comb_mat(grad, lr):
            A = tf.matmul(tf.transpose(grad), comb_matrix) - \
                tf.matmul(tf.transpose(comb_matrix), grad)
            I = tf.constant(np.eye(dm_shape[0]), dtype=_float_type)
            t1 = I + lr / 2 * A
            t2 = I - lr / 2 * A
            Y = tf.matmul(tf.matmul(tf.matrix_inverse(t1), t2), comb_matrix)
            return tf.assign(comb_matrix, Y)

        # Visualization
        cb_min = tf.reduce_min(comb_matrix)
        cb_max = tf.reduce_max(comb_matrix)
        comb_matrix_image = (comb_matrix - cb_min) / (cb_max - cb_min) * 255.0
        comb_matrix_image = tf.cast(comb_matrix_image, tf.uint8)
        comb_matrix_image = tf.reshape(comb_matrix_image, [1, dm_shape[0], dm_shape[1], 1])

        return poses, comb_matrix_image, update_comb_mat
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号