def _stitch_mat_from_vecs(vector_list):
    """Stitch a list of 16 equally sized tensors into a batch of 4x4 matrices.

    Args:
        vector_list: list of exactly 16 tensors holding the matrix elements in
            row-major order (m11, m12, m13, m14, m21, m22, ..., m44). Each
            tensor is flattened; its element count is interpreted as the batch
            size and must be identical for all 16 tensors (as must the dtype).

    Returns:
        Tensor of shape [batch_size, 4, 4] where result[b, i, j] is element b
        of the flattened vector_list[4 * i + j].

    Raises:
        ValueError: if vector_list does not contain exactly 16 tensors.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(vector_list) != 16:
        raise ValueError("There have to be exactly 16 tensors in vector_list.")
    # Flatten every element to [batch_size]. Using -1 lets TF infer the batch
    # size at run time, so an unknown (None) static batch dimension works.
    flat = [tf.reshape(x, [-1]) for x in vector_list]
    # Stacking along axis 1 yields [batch_size, 16] with the elements in
    # row-major order, which reshapes straight to [batch_size, 4, 4]. This is
    # equivalent to the former dynamic_stitch + reshape + transpose sequence.
    stacked = tf.stack(flat, axis=1)
    return tf.reshape(stacked, [-1, 4, 4])
评论列表
文章目录