def _layer_BatchNorm(self):
    """Register the ``__batch_normalization`` helper with the code emitter.

    Hands a single multi-line source template to ``self.add_body`` with a
    first argument of ``0`` (presumably the indent level -- confirm against
    ``add_body``).  The template defines a static method
    ``__batch_normalization(dim, name, **kwargs)`` that:

    * creates ``nn.BatchNorm1d`` / ``nn.BatchNorm2d`` / ``nn.BatchNorm3d``
      for ``dim`` equal to 1 / 2 / 3, forwarding ``**kwargs``, and raises
      ``NotImplementedError`` for any other ``dim``;
    * copies ``__weights_dict[name]['scale']`` into the layer weight when
      that key exists, otherwise fills the weight with 1;
    * copies ``__weights_dict[name]['bias']`` into the layer bias when that
      key exists, otherwise fills the bias with 0;
    * unconditionally copies ``'mean'`` and ``'var'`` into ``running_mean``
      and ``running_var`` (a missing key raises ``KeyError`` in the
      generated code);
    * returns the configured layer.

    The template must carry its own indentation, since it is emitted as
    Python source.  It is indented one level (4 spaces) on the assumption
    that it lands inside a class body, as ``@staticmethod`` suggests --
    TODO confirm against the emitter that writes the class header.
    """
    # NOTE: the text between the triple quotes is emitted verbatim as Python
    # source, so its leading whitespace is significant and must stay consistent.
    self.add_body(0, """
    @staticmethod
    def __batch_normalization(dim, name, **kwargs):
        if dim == 1: layer = nn.BatchNorm1d(**kwargs)
        elif dim == 2: layer = nn.BatchNorm2d(**kwargs)
        elif dim == 3: layer = nn.BatchNorm3d(**kwargs)
        else: raise NotImplementedError()
        if 'scale' in __weights_dict[name]:
            layer.state_dict()['weight'].copy_(torch.from_numpy(__weights_dict[name]['scale']))
        else:
            layer.weight.data.fill_(1)
        if 'bias' in __weights_dict[name]:
            layer.state_dict()['bias'].copy_(torch.from_numpy(__weights_dict[name]['bias']))
        else:
            layer.bias.data.fill_(0)
        layer.state_dict()['running_mean'].copy_(torch.from_numpy(__weights_dict[name]['mean']))
        layer.state_dict()['running_var'].copy_(torch.from_numpy(__weights_dict[name]['var']))
        return layer""")
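# NOTE(review): web-page navigation residue ("Comment list" / "Article table of
# contents") followed the code here; it is not part of the module.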