def batchnorm(bottom, is_train, num_reference, epsilon=1e-3, decay=0.999, name=None):
    """Virtual batch normalization (simplified / "poor man's" version).

    The incoming batch is laid out as [true batch | reference batch]: the
    first ``batch_size - num_reference`` rows are the true batch and the
    trailing ``num_reference`` rows are the reference batch. With
    ``num_reference == 0`` this reduces to ordinary batch normalization.

    To use virtual batch normalization at test time, "update_popmean.py"
    must be executed first so the reference batch's mean and variance are
    stored in the layer's pop_mean / pop_variance variables.

    Args:
        bottom: input tensor; NOTE(review): assumes the batch (first)
            dimension is static — ``get_shape()`` must yield a concrete
            batch size, otherwise ``np.ones([None])`` fails. TODO confirm.
        is_train: bool (or bool tensor) forwarded as ``is_training``.
        num_reference: number of trailing reference examples in the batch.
        epsilon: variance epsilon forwarded to ``slim.batch_norm``.
        decay: moving-average decay; forced to 0.0 when the batch contains
            no true-batch examples (pure reference pass).
        name: variable scope for the batch-norm layer.

    Returns:
        The normalized tensor produced by ``slim.batch_norm``.
    """
    batch_size = bottom.get_shape().as_list()[0]
    inst_size = batch_size - num_reference
    instance_weight = np.ones([batch_size])
    if inst_size > 0:
        # Each true-batch example gets weight 1/(num_reference + 1); each
        # reference example gets the complement, so the reference batch
        # carries most of the per-example weight in the moment estimates.
        reference_weight = 1.0 - (1.0 / (num_reference + 1.0))
        instance_weight[0:inst_size] = 1.0 - reference_weight
        instance_weight[inst_size:] = reference_weight
    else:
        # Whole batch is reference data: freeze the moving averages so this
        # pass only reads (never pollutes) pop_mean / pop_variance.
        decay = 0.0
    # BUGFIX: epsilon was previously accepted but never forwarded, so callers
    # passing a custom epsilon were silently ignored (the default 1e-3 happens
    # to equal slim's own default, masking the bug).
    return slim.batch_norm(bottom, activation_fn=None, is_training=is_train,
                           decay=decay, epsilon=epsilon, scale=True,
                           scope=name, batch_weights=instance_weight)
# NOTE(review): the two lines below were blog-scrape residue ("comment list",
# "article table of contents") accidentally pasted into this file; kept here
# as a comment so the module remains syntactically valid.
# 评论列表 / 文章目录