def _batch_normalization(self, x, layer_name, eps=0.001):
with tf.variable_scope(layer_name.split('/')[-1]):
beta, gamma, mean, variance = self._get_batch_normalization_weights(layer_name)
# beta, gamma, mean, variance are numpy arrays!!!
if beta is None:
try:
net = tf.layers.batch_normalization(x, epsilon = eps)
except:
net = tf.nn.batch_normalization(x, 0, 1, 0, 1, 0.01)
else:
try:
net = tf.layers.batch_normalization(x, epsilon = eps,
beta_initializer = tf.constant_initializer(value=beta,dtype=tf.float32),
gamma_initializer = tf.constant_initializer(value=gamma,dtype=tf.float32),
moving_mean_initializer = tf.constant_initializer(value=mean,dtype=tf.float32),
moving_variance_initializer = tf.constant_initializer(value=variance,dtype=tf.float32),
)
except:
net = tf.nn.batch_normalization(x, mean, variance, beta, gamma, 0.01)
mean = '%s/batch_normalization/moving_mean:0'%(layer_name)
variance = '%s/batch_normalization/moving_variance:0'%(layer_name)
try:
tf.add_to_collection(tf.GraphKeys.SAVE_TENSORS, tf.get_default_graph().get_tensor_by_name(mean))
tf.add_to_collection(tf.GraphKeys.SAVE_TENSORS, tf.get_default_graph().get_tensor_by_name(variance))
except:
pass
return net
评论列表
文章目录