network_vgg16.py 文件源码

python
阅读 33 收藏 0 点赞 0 评论 0

项目:HandDetection 作者: YunqiuXu 项目源码 文件源码
def batch_norm_layer(self, to_be_normalized, is_training):
    if is_training:
      train_phase = tf.constant(1)
    else:
      train_phase = tf.constant(-1)
    beta = tf.Variable(tf.constant(0.0, shape=[to_be_normalized.shape[-1]]), name='beta', trainable=True)
    gamma = tf.Variable(tf.constant(1.0, shape=[to_be_normalized.shape[-1]]), name='gamma', trainable=True)
    # axises = np.arange(len(to_be_normalized.shape) - 1) # change to apply tensorflow 1.3
    axises = [0,1,2]

    print("start nn.moments")
    print("axises : " + str(axises))
    batch_mean, batch_var = tf.nn.moments(to_be_normalized, axises, name='moments')
    print("nn.moments successful")
    ema = tf.train.ExponentialMovingAverage(decay=0.5)

    def mean_var_with_update():
        ema_apply_op = ema.apply([batch_mean, batch_var])
        with tf.control_dependencies([ema_apply_op]):
            return tf.identity(batch_mean), tf.identity(batch_var)

    mean, var = tf.cond(train_phase > 0, mean_var_with_update, lambda: (ema.average(batch_mean), ema.average(batch_var))) # if is training --> update
    normed = tf.nn.batch_normalization(to_be_normalized, mean, var, beta, gamma, 1e-3)
    return normed
评论列表
文章目录


问题


面经


文章

微信
公众号

扫码关注公众号