def orthogonal_initializer():
    """Return an orthogonal initializer.

    A random orthogonal matrix is obtained as a byproduct of the singular
    value decomposition applied to a matrix sampled from a normal
    distribution.

    The initializer works with 2D square matrices and with matrices that can
    be split along axis 1 into several square 2D matrices. In the latter
    case, each submatrix is initialized independently and the resulting
    orthogonal matrices are concatenated along axis 1.

    Note this is a higher order function in order to mimic the tensorflow
    initializer API.

    Returns:
        A function ``func(shape, dtype, partition_info=None)`` that builds
        the initial value for a variable of the given shape.
    """
    # pylint: disable=unused-argument
    def func(shape, dtype, partition_info=None):
        """Build a ``[shape[0], shape[1]]`` value from orthogonal blocks.

        Args:
            shape: Two-element shape; ``shape[1]`` must be a multiple of a
                non-zero ``shape[0]``.
            dtype: Data type passed to the random normal sampler.
            partition_info: Unused; accepted only for API compatibility.

        Returns:
            The concatenation along axis 1 of ``shape[1] // shape[0]``
            independently generated ``shape[0] x shape[0]`` matrices.

        Raises:
            ValueError: If ``shape`` is not 2D, or if ``shape[1]`` cannot be
                split evenly into square blocks of side ``shape[0]``.
        """
        if len(shape) != 2:
            raise ValueError(
                "Orthogonal initializer only works with 2D matrices.")

        dim = shape[0]
        # A zero-sized first dimension cannot form square blocks; reject it
        # explicitly so the modulo below never divides by zero.
        if dim == 0 or shape[1] % dim != 0:
            raise ValueError("Shape {} is not compatible with orthogonal "
                             "initializer.".format(str(shape)))

        # Floor division keeps the block count exact (no float round trip).
        mult = int(shape[1] // dim)

        # Element 1 of the SVD result is used as the orthogonal matrix; every
        # block is decomposed from its own independent random sample.
        orthogonals = [
            tf.svd(tf.random_normal([dim, dim], dtype=dtype))[1]
            for _ in range(mult)
        ]
        return tf.concat(orthogonals, 1)
    # pylint: enable=unused-argument

    return func
# pylint: disable=too-few-public-methods
# Comment list
# Table of contents