def __init__(self, num_features, max_length, eps=1e-5, momentum=0.1,
             affine=True):
    """
    Set up the module's configuration, parameters and running statistics.

    Most parts are copied from
    torch.nn.modules.batchnorm._BatchNorm.

    Args:
        num_features: Length of every parameter and buffer vector
            created here.
        max_length: Number of separate (running_mean, running_var)
            buffer pairs to register, indexed 0 .. max_length - 1.
        eps: Stored as ``self.eps``; presumably the variance stabiliser
            used by the forward pass (not visible here -- confirm).
        momentum: Stored as ``self.momentum``; presumably the
            running-statistics update rate (confirm against forward).
        affine: When true, ``weight`` and ``bias`` are created as
            parameters of size ``num_features``; otherwise both are
            registered as ``None``.
    """
    super(SeparatedBatchNorm1d, self).__init__()

    # Plain configuration attributes.
    self.num_features = num_features
    self.max_length = max_length
    self.affine = affine
    self.eps = eps
    self.momentum = momentum

    if not affine:
        # Keep both names registered on the module even though there
        # is nothing to learn.
        self.register_parameter('weight', None)
        self.register_parameter('bias', None)
    else:
        # torch.FloatTensor(n) allocates without filling;
        # reset_parameters() is invoked at the end of this constructor.
        for param_name in ('weight', 'bias'):
            storage = torch.FloatTensor(num_features)
            setattr(self, param_name, nn.Parameter(storage))

    # One independent pair of running statistics per index: means start
    # at zero, variances start at one.
    for step in range(max_length):
        mean_key = 'running_mean_{}'.format(step)
        var_key = 'running_var_{}'.format(step)
        self.register_buffer(mean_key, torch.zeros(num_features))
        self.register_buffer(var_key, torch.ones(num_features))

    self.reset_parameters()
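# Residue from the web page this snippet was copied from (navigation
# labels "Comment list" and "Article table of contents"); not code.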