def _stitch_mat_from_vecs(vector_list):
    """Stitch a list of 9 equally long vectors into a batch of 3x3 matrices.

    Input:
        vector_list: sequence of exactly 9 tensors holding the matrix elements
            in row-first order (m11, m12, m13, m21, m22, m23, m31, m32, m33).
            Every tensor must contain the same number of elements; that count
            is interpreted as the batch dimension. Each tensor is flattened
            first, so e.g. shapes [batch] and [batch, 1] are both accepted.
            All tensors must share the same dtype.
    Returns:
        Tensor of shape [batch, 3, 3] where result[b, i, j] is element b of
        vector_list[3 * i + j].
    Raises:
        ValueError: if vector_list does not contain exactly 9 tensors.
    """
    if len(vector_list) != 9:
        raise ValueError("There have to be exactly 9 tensors in vector_list.")
    # Flatten with -1 rather than a statically read batch size, so tensors whose
    # batch dimension is unknown at graph-construction time (None) still work.
    flat_vectors = [tf.reshape(x, [-1]) for x in vector_list]
    # Stacking along axis 1 gives [batch, 9] with the elements already in
    # row-first order, so a single reshape yields [batch, 3, 3] directly
    # (equivalent to the former dynamic_stitch + reshape + transpose).
    stacked = tf.stack(flat_vectors, axis=1)
    return tf.reshape(stacked, [-1, 3, 3])
评论列表
文章目录