def __init__(self, num_features, max_len, eps=1e-5, momentum=0.1, affine=True):
    """Set up a batch-norm layer that keeps a separate pair of running statistics per index.

    Args:
        num_features: Size of the feature dimension; the length of ``weight``,
            ``bias`` and of every running-statistics buffer.
        max_len: Number of ``running_mean_{i}`` / ``running_var_{i}`` buffer
            pairs to create (``i`` from ``0`` to ``max_len - 1``) -- presumably
            one pair per time step of the recurrence; confirm against ``forward``.
        eps: Stored as ``self.eps``; presumably the variance stabilizer used
            in ``forward`` (not visible here).
        momentum: Stored as ``self.momentum``; presumably the running-stat
            update factor used in ``forward`` (not visible here).
        affine: If True, create learnable ``weight`` and ``bias`` parameters of
            length ``num_features``; otherwise both are registered as ``None``.
    """
    super(recurrent_BatchNorm, self).__init__()
    self.num_features = num_features
    self.affine = affine
    self.max_len = max_len
    self.eps = eps
    self.momentum = momentum
    if self.affine:
        # Assigning an nn.Parameter to a module attribute registers it
        # automatically, so no explicit register_parameter() call is needed.
        # torch.Tensor(n) allocates uninitialized storage; reset_parameters()
        # (called below) presumably fills these in -- confirm.
        self.weight = nn.Parameter(torch.Tensor(num_features))
        self.bias = nn.Parameter(torch.Tensor(num_features))
    else:
        # Register as None so that ``self.weight`` / ``self.bias`` still
        # resolve and show up (as None) in the module's parameter table.
        self.register_parameter('weight', None)
        self.register_parameter('bias', None)
    # One mean/variance buffer pair per index; means start at 0, variances at 1.
    # ``range`` (not the Python-2-only ``xrange``) so this runs on Python 3
    # as well as Python 2.
    for i in range(max_len):
        self.register_buffer('running_mean_{}'.format(i), torch.zeros(num_features))
        self.register_buffer('running_var_{}'.format(i), torch.ones(num_features))
    # Defined elsewhere in the class.
    self.reset_parameters()
recurrent_BatchNorm.py 文件源码
python
阅读 31
收藏 0
点赞 0
评论 0
评论列表
文章目录