def orthogonal(shape, gain=1.0, *, rng=None, dtype=None):
    """Return a random (semi-)orthogonal weight array of the given shape.

    A standard-normal matrix of shape ``(shape[0], prod(shape[1:]))`` is
    drawn and factorised with a reduced SVD. Of the two singular-vector
    factors, the one whose shape equals that flattened shape is kept (it
    has orthonormal rows or orthonormal columns). It is then reshaped to
    ``shape`` and scaled by ``gain``.

    Parameters
    ----------
    shape : sequence of int
        Target shape; must have at least two dimensions. All dimensions
        after the first are flattened together for the decomposition.
    gain : float or 'relu'
        Multiplicative scale applied to the orthogonal matrix. The string
        ``'relu'`` is shorthand for ``sqrt(2)``.
    rng : object, optional
        Source of randomness exposing ``normal(loc, scale, size)`` (e.g. a
        ``numpy.random.RandomState`` or ``numpy.random.Generator``).
        Defaults to the module-level ``get_rng()``.
    dtype : numpy dtype, optional
        Output dtype. Defaults to the module-level ``floatX``.

    Returns
    -------
    numpy.ndarray
        Array with shape ``shape`` and dtype ``dtype``.

    Raises
    ------
    RuntimeError
        If ``shape`` has fewer than two dimensions.
    """
    # Only compare against the string alias when gain really is a string;
    # an ndarray gain would otherwise trigger an elementwise comparison.
    if isinstance(gain, str) and gain == 'relu':
        gain = np.sqrt(2)
    if len(shape) < 2:
        raise RuntimeError("Only shapes of length 2 or more are supported, but "
                           "given shape:%s" % str(shape))
    if rng is None:
        rng = get_rng()
    if dtype is None:
        dtype = floatX

    flat_shape = (shape[0], int(np.prod(shape[1:])))
    a = rng.normal(0.0, 1.0, flat_shape)
    # With full_matrices=False, u is (m, k) and v is (k, n) with
    # k = min(m, n), so exactly the factor matching (m, n) is the one with
    # orthonormal rows/columns of the requested size.
    u, _, v = np.linalg.svd(a, full_matrices=False)
    q = u if u.shape == flat_shape else v
    q = q.reshape(shape)
    # np.cast was removed in NumPy 2.0; asarray(..., dtype=...) is the
    # supported way to convert to the target dtype.
    return np.asarray(gain * q, dtype=dtype)
# ===========================================================================
# Fast initialization
# ===========================================================================
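# NOTE(review): two stray web-page labels ("comment list", "table of
# contents") were pasted here; they are not code and raised NameError on import.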