def normal(tensor, mean=0, std=1):
    """Fill the input Tensor or Variable in place with normally distributed values.

    Values are drawn from the normal distribution with mean ``mean`` and
    standard deviation ``std`` by calling ``normal_(mean, std)`` on the
    underlying tensor.

    Args:
        tensor: an n-dimensional torch.Tensor or autograd.Variable
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution

    Returns:
        The ``tensor`` argument itself when it is a Variable; otherwise the
        result of ``tensor.normal_(mean, std)``.

    Examples:
        >>> w = torch.Tensor(3, 5)
        >>> nn.init.normal(w)
    """
    if isinstance(tensor, Variable):
        # Fill the wrapped data directly instead of recursing into normal().
        # On PyTorch releases where Variable and Tensor are the same type
        # (0.4+), ``tensor.data`` also passes the isinstance check above, so
        # a recursive call would never terminate.
        tensor.data.normal_(mean, std)
        return tensor
    return tensor.normal_(mean, std)
评论列表
文章目录