def vec_to_tri(vectors, N):
    """
    Unpack each row of `vectors` into the lower triangle of an N x N matrix.

    Takes a D x M tensor `vectors` and maps it to a D x N x N tensor, where
    the lower triangle (diagonal included) of each N x N matrix is filled
    from the corresponding M-vector in row-major order and every other entry
    is zero.

    Native TensorFlow version of Custom Op by Mark van der Wilk.

    :param vectors: D x M tensor. M is expected to equal N * (N + 1) / 2,
        the number of lower-triangular entries; this is not checked here.
    :param N: integer side length of each output matrix.
    :return: D x N x N tensor built by applying the scatter to every slice
        along the first axis of `vectors`.
    """
    tril_rows, tril_cols = np.tril_indices(N)
    # One [row, col] pair per lower-triangular entry, in row-major order.
    tril_positions = tf.constant(
        [[row, col] for row, col in zip(tril_rows, tril_cols)],
        dtype=tf.int64,
    )

    def scatter_row(flat_tri):
        # Positions not listed in `tril_positions` stay at zero.
        return tf.scatter_nd(
            indices=tril_positions, updates=flat_tri, shape=[N, N]
        )

    return tf.map_fn(scatter_row, vectors)
评论列表
文章目录