def create_permutation_matrix(input_size, seed=None):
    """Build a random ``input_size x input_size`` permutation matrix as a sparse tensor.

    A random permutation ``perm`` of ``range(input_size)`` is drawn, and the
    matrix has ``M[perm[i], i] = 1.0`` for every ``i`` with all other entries
    zero, so every row and every column holds exactly one 1.0.

    Args:
        input_size: Number of rows (and columns) of the square matrix.
            ``0`` yields an empty ``0 x 0`` matrix.
        seed: Seed for the permutation. The same seed always gives the same
            matrix; ``None`` draws fresh entropy from the OS.

    Returns:
        A ``tf.SparseTensor`` holding ``input_size`` float32 ones, with int64
        indices in row-major (sorted) order and dense shape
        ``[input_size, input_size]``.
    """
    # A private generator keeps the result reproducible for a given seed
    # without reseeding NumPy's global RNG as a side effect. RandomState(seed)
    # produces the same permutation that np.random.seed(seed) followed by
    # np.random.shuffle would.
    rng = np.random.RandomState(seed)
    perm = np.arange(input_size)
    rng.shuffle(perm)

    # Nonzeros sit at (perm[i], i). Listed row by row, row r's column is the
    # i with perm[i] == r, i.e. the inverse permutation argsort(perm). This
    # gives the row-major ordering SparseTensor ops expect, and column_stack
    # keeps an (0, 2) shape when input_size == 0.
    rows = np.arange(input_size)
    cols = np.argsort(perm)
    indices = np.column_stack((rows, cols)).astype(np.int64)
    values = np.ones(input_size, dtype=np.float32)

    # Dense shape is passed positionally: the keyword was ``shape`` before
    # TensorFlow 1.0 and ``dense_shape`` afterwards.
    return tf.SparseTensor(indices, values, [input_size, input_size])
评论列表
文章目录