def glorot_normal(shape, gain=1.0, c01b=False):
    """Sample an array using Glorot (Xavier) normal initialization.

    Values are drawn from a normal distribution with mean 0 and standard
    deviation ``gain * sqrt(2 / ((n1 + n2) * receptive_field_size))``,
    where ``n1``/``n2`` are the two fan dimensions of ``shape``.

    Parameters
    ----------
    shape : sequence of int
        Shape of the array to sample. The returned array always has
        exactly this shape; the dimensions are only reinterpreted to
        compute the standard deviation (see ``c01b``).
    gain : float, optional
        Multiplicative factor applied to the standard deviation.
    c01b : bool, optional
        If True, ``shape`` must have length 4. ``shape[0]`` and
        ``shape[3]`` are the fan dimensions, and ``shape[1] * shape[2]``
        is the receptive-field size. If False, ``shape[0]`` and
        ``shape[1]`` are the fan dimensions, and the product of the
        remaining dimensions is the receptive-field size. Shapes with
        fewer than two dimensions are left-padded with 1s for this
        computation only.

    Returns
    -------
    numpy.ndarray
        Array of shape ``shape`` with dtype ``floatX``. The samples come
        from ``get_rng().normal``, which is defined elsewhere in this
        module (presumably a shared random generator; TODO confirm).

    Raises
    ------
    RuntimeError
        If ``c01b`` is True and ``shape`` does not have length 4.
    """
    orig_shape = shape
    if c01b:
        if len(shape) != 4:
            raise RuntimeError(
                "If c01b is True, only shapes of length 4 are accepted")
        n1, n2 = shape[0], shape[3]
        receptive_field_size = shape[1] * shape[2]
    else:
        # Left-pad with 1s so that both 1-d and 0-d shapes provide two fan
        # dimensions. Padding only to (1,) made `()` fail to unpack below.
        shape = (1,) * max(0, 2 - len(shape)) + tuple(shape)
        n1, n2 = shape[:2]
        receptive_field_size = np.prod(shape[2:])

    denominator = (n1 + n2) * receptive_field_size
    # The denominator is 0 only when some dimension is 0, i.e. the requested
    # array is empty and std is irrelevant. The guard avoids 2.0 / 0, which
    # raises ZeroDivisionError for Python ints and warns for NumPy ints.
    std = gain * np.sqrt(2.0 / denominator) if denominator else 0.0

    samples = get_rng().normal(0.0, std, size=orig_shape)
    # np.cast was removed in NumPy 2.0. np.asarray(x, dtype=...) is the
    # documented replacement and yields the same values and dtype.
    return np.asarray(samples, dtype=floatX)
# [web-page residue] 评论列表 ("Comment list")
# [web-page residue] 文章目录 ("Table of contents")