def weights_init(self, module):
    """Initialise the Conv2d and BatchNorm2d layers of ``module`` in place.

    Every module yielded by ``module.modules()`` (which includes ``module``
    itself) is visited:

    * ``nn.Conv2d``: the weight is drawn with Xavier/Glorot uniform
      initialisation using a gain of ``sqrt(2)``. The bias, if the layer has
      one, is set to 0.
    * ``nn.BatchNorm2d``: the weight (scale), if present, is set to 1. The
      bias (shift), if present, is set to 0.

    All other module types are left unchanged.

    Args:
        self: Not used by the body. The signature suggests this was a method
            whose class indentation was lost (NOTE(review): confirm).
        module: The ``nn.Module`` whose submodules are initialised.

    Returns:
        None. Parameters are modified in place.
    """
    for m in module.modules():
        if isinstance(m, nn.Conv2d):
            # The trailing-underscore initialisers are the current in-place
            # API. The old ``xavier_uniform``/``constant`` names are
            # deprecated aliases that warn and forward to them.
            init.xavier_uniform_(m.weight, gain=np.sqrt(2))
            # Conv2d(bias=False) stores ``bias = None``. Skip it instead of
            # crashing inside the initialiser.
            if m.bias is not None:
                init.constant_(m.bias, 0)
        elif isinstance(m, nn.BatchNorm2d):
            # BatchNorm2d(affine=False) has no learnable weight or bias
            # (both attributes are None), so guard each one separately.
            if m.weight is not None:
                init.ones_(m.weight)
            if m.bias is not None:
                init.zeros_(m.bias)
评论列表
文章目录