canonical_trafo.py source code

python

Project: hand3d  Author: lmb-freiburg
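import tensorflow as tf


def _stitch_mat_from_vecs(vector_list):
    """ Stitches a given list of vectors into a batch of 3x3 matrices.

        Input:
            vector_list: list of 9 tensors that will be stitched into a matrix. The list contains the matrix
                elements in row-major order (m11, m12, m13, m21, m22, m23, m31, m32, m33). All vectors must
                have the same length, which is interpreted as the batch dimension.

        Output:
            trafo_matrix: tensor of shape [batch_size, 3, 3].
    """

    assert len(vector_list) == 9, "There have to be exactly 9 tensors in vector_list."
    # The batch size must be known statically, because it is used in the reshapes below.
    batch_size = vector_list[0].get_shape().as_list()[0]
    # Turn every element vector into a [1, batch_size] row.
    vector_list = [tf.reshape(x, [1, batch_size]) for x in vector_list]

    # Place row i at position i, giving a [9, batch_size] tensor.
    trafo_matrix = tf.dynamic_stitch([[0], [1], [2],
                                      [3], [4], [5],
                                      [6], [7], [8]], vector_list)

    # Split the 9 entries into 3 rows of 3 columns: [3, 3, batch_size].
    trafo_matrix = tf.reshape(trafo_matrix, [3, 3, batch_size])
    # Move the batch axis to the front: [batch_size, 3, 3].
    trafo_matrix = tf.transpose(trafo_matrix, [2, 0, 1])

    return trafo_matrix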
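To make the index bookkeeping of `_stitch_mat_from_vecs` concrete, here is a minimal pure-Python sketch that mirrors its three tensor steps on nested lists: stitch into 9 x batch, reshape into 3 x 3 x batch, then transpose the batch axis to the front. The helper name `stitch_mat_from_lists` is hypothetical and not part of hand3d. It only illustrates the resulting layout, where entry (r, c) of matrix b is element b of vector 3*r + c, and is not a replacement for the TensorFlow version.

```python
from typing import List, Sequence

Matrix3 = List[List[float]]


def stitch_mat_from_lists(vector_list: Sequence[Sequence[float]]) -> List[Matrix3]:
    """Hypothetical pure-Python reference for the stitching layout.

    vector_list holds 9 equal-length sequences (m11, m12, ..., m33) in
    row-major order; element b of each sequence belongs to batch item b.
    Returns batch_size 3x3 matrices as nested lists.
    """
    assert len(vector_list) == 9, "There have to be exactly 9 sequences in vector_list."
    batch_size = len(vector_list[0])
    assert all(len(v) == batch_size for v in vector_list), "All sequences need the same length."

    # Step 1 (like tf.dynamic_stitch with indices 0..8): a 9 x batch_size table.
    stitched = [list(v) for v in vector_list]

    # Step 2 (like tf.reshape to [3, 3, batch_size]): row r, column c is entry 3*r + c.
    cube = [[stitched[3 * r + c] for c in range(3)] for r in range(3)]

    # Step 3 (like tf.transpose with perm [2, 0, 1]): move the batch axis to the front.
    return [[[cube[r][c][b] for c in range(3)] for r in range(3)] for b in range(batch_size)]


if __name__ == "__main__":
    # A batch of two diagonal matrices: identity and 2 * identity.
    ones, zeros = [1.0, 2.0], [0.0, 0.0]
    mats = stitch_mat_from_lists([ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones])
    print(mats[1])  # [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]]
```

As a design note, in TensorFlow `tf.reshape(tf.stack(vector_list, axis=1), [-1, 3, 3])` should produce the same row-major layout for 1-D inputs without needing a statically known batch size. That is an alternative, not what the original code does.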