def U_weight_init(ms):
    """Re-initialise, in place, the parameters of every sub-module of ``ms``.

    Sub-modules are selected by a substring match on their class name,
    tested in this order:

    * ``Conv2d``          -- Kaiming-normal weights, called with ``a=0.2``.
    * ``ConvTranspose2d`` -- Kaiming-normal weights, default arguments.
    * ``BatchNorm``       -- weights filled via ``normal_(1.0, 0.02)``,
      bias filled with 0.  Layers that carry no learnable weight/bias
      (attributes set to ``None``) are skipped instead of crashing.
    * ``Linear``          -- Kaiming-normal weights, default arguments.

    Sub-modules matching none of the names are left untouched, and the
    biases of the convolution / linear layers are not modified.

    Args:
        ms: an object exposing ``modules()`` that yields the sub-modules to
            initialise (presumably a ``torch.nn.Module`` -- confirm with
            the callers).

    Returns:
        None. The parameters are modified in place.
    """
    # Prefer the ``kaiming_normal_`` spelling when the ``init`` module in
    # scope provides it; otherwise fall back to the legacy name that the
    # original code called, so both old and new releases keep working.
    kaiming_normal = (
        getattr(init, 'kaiming_normal_', None) or init.kaiming_normal
    )

    for m in ms.modules():
        classname = m.__class__.__name__
        if 'Conv2d' in classname:
            m.weight.data = kaiming_normal(m.weight.data, a=0.2)
        elif 'ConvTranspose2d' in classname:
            m.weight.data = kaiming_normal(m.weight.data)
        elif 'BatchNorm' in classname:
            # A norm layer may have been built without learnable parameters,
            # in which case these attributes are None and must be skipped.
            if getattr(m, 'weight', None) is not None:
                m.weight.data.normal_(1.0, 0.02)
            if getattr(m, 'bias', None) is not None:
                m.bias.data.fill_(0)
        elif 'Linear' in classname:
            m.weight.data = kaiming_normal(m.weight.data)
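# NOTE: the two lines that stood here ("comment list" / "article table of
# contents") appear to be web-page navigation labels, not part of the code.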