rand.py 文件源码

python
阅读 34 收藏 0 点赞 0 评论 0

项目:odin 作者: imito 项目源码 文件源码
def orthogonal(shape, gain=1.0):
  if gain == 'relu':
    gain = np.sqrt(2)

  if len(shape) < 2:
    raise RuntimeError("Only shapes of length 2 or more are supported, but "
                       "given shape:%s" % str(shape))

  flat_shape = (shape[0], np.prod(shape[1:]))
  a = get_rng().normal(0.0, 1.0, flat_shape)
  u, _, v = np.linalg.svd(a, full_matrices=False)
  # pick the one with the correct shape
  q = u if u.shape == flat_shape else v
  q = q.reshape(shape)
  return np.cast[floatX](gain * q)


# ===========================================================================
# Fast initialization
# ===========================================================================
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号