rand.py 文件源码

python
阅读 28 收藏 0 点赞 0 评论 0

项目:odin 作者: imito 项目源码 文件源码
def glorot_normal(shape, gain=1.0, c01b=False):
  orig_shape = shape
  if c01b:
    if len(shape) != 4:
      raise RuntimeError(
          "If c01b is True, only shapes of length 4 are accepted")
    n1, n2 = shape[0], shape[3]
    receptive_field_size = shape[1] * shape[2]
  else:
    if len(shape) < 2:
      shape = (1,) + tuple(shape)
    n1, n2 = shape[:2]
    receptive_field_size = np.prod(shape[2:])

  std = gain * np.sqrt(2.0 / ((n1 + n2) * receptive_field_size))
  return np.cast[floatX](
      get_rng().normal(0.0, std, size=orig_shape))
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号