def weights_init(m):
    """Re-initialise the weight of a conv/linear module with Kaiming-uniform.

    Takes a single module, which suggests it is meant to be used as
    ``model.apply(weights_init)`` (assumption from the signature; confirm
    against callers). Only ``m.weight`` of ``nn.Conv2d`` and ``nn.Linear``
    modules is overwritten, in place. Biases and all other module types are
    left untouched.

    Uses ``mode='fan_in'`` with the ``kaiming_uniform_`` defaults
    (``a=0``, ``nonlinearity='leaky_relu'``). Weights are therefore drawn
    from U(-b, b) with b = sqrt(2) * sqrt(3 / fan_in).

    Args:
        m: Module to (possibly) initialise.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        # Use the in-place, trailing-underscore variant; the non-underscore
        # ``kaiming_uniform`` is deprecated. ``torch.nn.init`` functions run
        # under ``torch.no_grad()``, so the Parameter itself can be passed
        # without going through ``.data``.
        init.kaiming_uniform_(m.weight, mode='fan_in')
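# Scraped web-page residue (not code): "评论列表" (comment list)
# Scraped web-page residue (not code): "文章目录" (table of contents)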