def random_orthonormal_initializer(shape, dtype=tf.float32, partition_info=None):
    """Variable initializer that produces a random orthonormal (square) matrix.

    The result is the left singular-vector factor ``u`` of the SVD of a
    matrix whose entries are drawn with ``tf.random_normal``.

    Args:
      shape: Shape of the variable; must have exactly two equal dimensions.
      dtype: Data type of the generated matrix (default ``tf.float32``).
      partition_info: Unused here; presumably accepted so this function can
        be passed wherever TensorFlow expects an initializer callable.

    Returns:
      A square random orthogonal matrix tensor of the requested ``shape``,
      for use as an initial value.

    Raises:
      ValueError: If ``shape`` is not a square 2-D shape.
    """
    if len(shape) != 2 or shape[0] != shape[1]:
        # Wrap ``shape`` in a 1-tuple: formatting a bare tuple with ``%`` would
        # treat its elements as separate format arguments and raise TypeError
        # instead of the intended ValueError.
        raise ValueError("Expecting square shape, got %s" % (shape,))
    # tf.svd returns (s, u, v); only the orthonormal factor u is needed.
    _, u, _ = tf.svd(tf.random_normal(shape, dtype=dtype), full_matrices=True)
    return u
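# NOTE(review): two stray lines of scraped web-page residue ("comment list",
# "table of contents") stood here; they are not code and were commented out.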